llama.py 27.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
5
# Copyright 2023 The vLLM team.
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
24
"""Inference-only LLaMA model compatible with HuggingFace weights."""
25
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28
29
30

import torch
from torch import nn
from transformers import LlamaConfig

zhuwenwen's avatar
zhuwenwen committed
31
import os
gaoqiong's avatar
gaoqiong committed
32
import re
33
import vllm.envs as envs
34
from vllm.attention import Attention
zhuwenwen's avatar
zhuwenwen committed
35

36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.activation import SiluAndMul
40
from vllm.model_executor.layers.layernorm import RMSNorm
41
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
42
43
                                               QKVParallelLinear,
                                               RowParallelLinear)
44
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.rotary_embedding import get_rope
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
49
from vllm.model_executor.model_loader.weight_utils import (
50
    default_weight_loader, maybe_remap_kv_scale_name)
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
53

54
from .interfaces import SupportsLoRA, SupportsPP
55
56
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

gaoqiong's avatar
gaoqiong committed
60
from vllm import _custom_ops as ops
61
62
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
63
64

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear``; its output goes through ``SiluAndMul``
    and is then projected back to ``hidden_size`` by a ``RowParallelLinear``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Build the two projections and the activation.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config forwarded to both projections.
            bias: Whether both projections carry a bias term.
            prefix: Dotted name prefix used to name the sub-layers.
            reduce_results: Forwarded to the down projection.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()

        # One fused matmul produces both the gate and the up activations.
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Each stage rebinds the same name so intermediates are not kept alive.
        hidden, _ = self.gate_up_proj(x)
        hidden = self.act_fn(hidden)
        hidden, _ = self.down_proj(hidden)
        return hidden


class LlamaAttention(nn.Module):
    """Self-attention block of a LLaMA decoder layer.

    ``forward`` runs a fused QKV projection, applies rotary embeddings to the
    query and key, calls the ``Attention`` layer and projects the result with
    ``o_proj``.
    """

    def __init__(self,
                 config: LlamaConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = "") -> None:
        """Create the projections, rotary embedding and attention layer.

        Args:
            config: Model config; optional ``head_dim``,
                ``partial_rotary_factor``, ``interleaved_sliding_window`` and
                ``model_type`` are read from it.
            hidden_size: Model hidden width.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP
                sharding).
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Optional scaling dict passed to ``get_rope``.
            max_position_embeddings: Max position passed to ``get_rope``.
            quant_config: Quantization config forwarded to sub-layers.
            bias: Whether the QKV projection has a bias.
            bias_o_proj: Whether the output projection has a bias.
            cache_config: Cache config forwarded to ``Attention``.
            prefix: Dotted name prefix; the layer index is extracted from it.

        Raises:
            ValueError: If ``config.interleaved_sliding_window`` exists but
                is neither an ``int`` nor a ``list``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo.
        # The attribute may exist but be None, so a plain getattr default is
        # not enough: fall back to the derived size in that case as well.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim

        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NeoX-style rotary layout is used unless this is a GGUF-quantized
        # checkpoint whose model_type is "llama".
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )

        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                # A list is cycled through by layer index.
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} is not supported.")
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Always define both attributes so readers never hit AttributeError
        # on an unquantized model.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is laid out as [q | k | v] along the
        # last dimension.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: RMSNorm, attention, RMSNorm, MLP.

    The residual stream is carried alongside the hidden states and returned
    to the caller rather than being added back inside this layer.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and the two norm layers from ``config``."""
        super().__init__()
        self.hidden_size = config.hidden_size

        # Rotary-embedding settings, with defaults for configs lacking them.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            original_max_pos = getattr(config,
                                       "original_max_position_embeddings",
                                       None)
            if original_max_pos:
                # Copy the value into the scaling dict (mutated in place).
                rope_scaling["original_max_position_embeddings"] = (
                    original_max_pos)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        bias_o_proj = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # support internlm/internlm3-8b with qkv_bias: when present it
        # overrides the QKV bias only, the o_proj bias keeps the value above.
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias
        else:
            attention_bias = bias_o_proj

        total_kv_heads = getattr(config, "num_key_value_heads",
                                 config.num_attention_heads)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=total_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

        mlp_bias = getattr(config, "mlp_bias", False)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=mlp_bias,
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` when no residual stream exists yet; in that
        case the incoming hidden states become the residual.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
302
303


304
@support_torch_compile
class LlamaModel(nn.Module):
    """LLaMA transformer stack: embeddings, decoder layers and final norm.

    Under pipeline parallelism only the pieces owned by the current rank are
    instantiated; the others are ``PPMissingLayer`` placeholders.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Source of the HF model config plus cache, quant and
                LoRA configs.
            prefix: Dotted name prefix for sub-layers.
            layer_type: Decoder layer class to instantiate per layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Extra vocabulary reserved for LoRA adapters, if any.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        # The embedding lives on the first rank, and also on the last rank
        # when the word embeddings are tied.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Loop indices (as enumerated in forward) whose input hidden state
        # should be captured and returned alongside the final output.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Feature switches read once from the environment at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                        list[torch.Tensor]]]:
        """Run this rank's slice of the decoder stack.

        Returns:
            ``IntermediateTensors`` on a non-last pipeline rank. On the last
            rank, the normalized hidden states, or a
            ``(hidden_states, aux_hidden_states)`` tuple when any auxiliary
            hidden states were captured.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
                self.layers[self.start_layer:self.end_layer]):
            if idx in self.aux_hidden_state_layers:
                if residual is None:
                    # No residual stream yet (first layer on the first rank):
                    # the hidden state alone is the value to capture. Clone
                    # it so the capture cannot alias a tensor that later
                    # layers might update in place (assumption: the norm
                    # kernels are not visible from this file).
                    aux_hidden_states.append(hidden_states.clone())
                else:
                    aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. When the
        ``LLAMA_NN`` switch is active, the model is unquantized and the last
        seen tensor reports ``current_count == total_count``, the projection
        weights are re-laid-out afterwards; otherwise the ``LM_NN`` and
        ``LLAMA_NN`` environment variables are set to ``'0'``.

        Args:
            weights: ``(name, tensor)`` pairs. With ``LLAMA_NN`` active each
                tensor must expose ``current_count`` and ``total_count``.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        # Bound up front so an empty ``weights`` iterable cannot leave these
        # undefined when they are compared after the loop.
        current_count = None
        total_count = None
        counts_seen = False

        for name, loaded_weight in weights:
            if self.use_llama_nn:
                current_count = loaded_weight.current_count
                total_count = loaded_weight.total_count
                counts_seen = True
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        all_weights_seen = counts_seen and current_count == total_count
        if (self.use_llama_nn and self.quant_method is None
                and all_weights_seen):
            self._relayout_weights_for_llama_nn(params_dict)
        else:
            # NOTE(review): this mutates process-wide state and also runs on
            # a non-final call (current_count != total_count); the readers of
            # these variables are outside this file, so confirm that
            # disabling LLAMA_NN here is intended.
            os.environ['LM_NN'] = '0'
            os.environ['LLAMA_NN'] = '0'

        return loaded_params

    def _relayout_weights_for_llama_nn(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Re-lay-out projection weights in place via ``ops.trans_w16_gemm``.

        Affects the attention and MLP projection weights, plus any
        ``lm_head.weight`` whose second dimension is at least 4096. ``LM_NN``
        is set once at the end: ``'1'`` if an ``lm_head.weight`` was
        converted, ``'0'`` otherwise, independent of parameter order.
        """
        layer_keywords = (
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        )
        lm_head_converted = False

        for layername, weight in params_dict.items():
            is_large_lm_head = ("lm_head.weight" in layername
                                and weight.shape[1] >= 4096)
            # Plain substring checks: the names contain dots, which would act
            # as wildcards if used as a regular expression.
            is_projection = any(key in layername for key in layer_keywords)
            if not (is_large_lm_head or is_projection):
                continue
            if is_large_lm_head:
                lm_head_converted = True

            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape
            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(_weight)
            # The parameter is then exposed with its two leading sizes
            # swapped.
            weight.data = weight.data.reshape(ori_shape[1], -1)

        os.environ['LM_NN'] = '1' if lm_head_converted else '0'
519

Woosuk Kwon's avatar
Woosuk Kwon committed
520

521
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """LLaMA decoder stack with a language-model head and logits processor.

    The LM head and logits processor are only created on the last pipeline
    rank; other ranks hold a ``PPMissingLayer`` in place of ``lm_head``.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Build the decoder stack and, on the last rank, the LM head.

        Args:
            vllm_config: Source of the HF model config plus quant and LoRA
                configs.
            prefix: Dotted name prefix for sub-modules.
            layer_type: Decoder layer class forwarded to ``_init_model``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"),
                                      layer_type=layer_type)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layer indices whose hidden states the model captures."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return three layer indices: 2, the midpoint, and third from last."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        """Create the inner ``LlamaModel``."""
        return LlamaModel(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner model and return its output unchanged."""
        model_output = self.model(input_ids, positions, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits using the LM head.

        Relies on ``logits_processor``, which ``__init__`` only creates on
        the last pipeline rank.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, remapping Mistral-format names first.

        ``lm_head.`` entries are skipped when the word embeddings are tied.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Translate a Mistral-format weight name (and layout) to HF format.

        ``wq`` / ``wk`` weight tensors are permuted; name components found in
        ``mistral_mapping`` are rewritten. Names with no mapped component are
        returned unchanged.

        Returns:
            The (possibly renamed) parameter name and the (possibly permuted)
            tensor.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # ``head_dim`` may be missing or None on the config; derive it
            # from the hidden size the same way LlamaAttention does.
            head_dim = getattr(self.config, "head_dim", None)
            if head_dim is None:
                head_dim = (self.config.hidden_size //
                            self.config.num_attention_heads)
            attn_in = head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules and modules[-1] == "weight":
            # Configs without num_key_value_heads use one KV head per query
            # head, matching the fallback in LlamaDecoderLayer.
            num_kv_heads = getattr(self.config, "num_key_value_heads",
                                   self.config.num_attention_heads)
            loaded_weight = permute(loaded_weight, num_kv_heads)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        num_modules = len(modules)
        for i, item in enumerate(modules):
            next_item = modules[i + 1] if i < num_modules - 1 else None

            # Two-component keys (e.g. "kv_fake_quantizer.qscale_act") take
            # precedence over single-component ones.
            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                # Skip when the target is already present so that names
                # already in HF format are not rewritten twice.
                name = name.replace(item, mapping[item])

        return name, loaded_weight