llama.py 28.9 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
23
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
29
import os
gaoqiong's avatar
gaoqiong committed
30
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
from vllm.attention import Attention, AttentionMetadata
33
from vllm.config import CacheConfig, LoRAConfig
34
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
35
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.activation import SiluAndMul
37
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
43
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
44
45
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
46
from vllm.model_executor.layers.rotary_embedding import get_rope
47
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
48
from vllm.model_executor.layers.vocab_parallel_embedding import (
49
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
50
from vllm.model_executor.model_loader.weight_utils import (
51
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
52
from vllm.model_executor.sampling_metadata import SamplingMetadata
53
from vllm.sequence import IntermediateTensors
54
from vllm.utils import is_hip
Woosuk Kwon's avatar
Woosuk Kwon committed
55

56
from .interfaces import SupportsLoRA
57
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
58

gaoqiong's avatar
gaoqiong committed
59
from vllm import _custom_ops as ops
60
61
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
62
63

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    ``gate_up_proj`` is one column-parallel linear producing two
    ``intermediate_size`` outputs (gate and up) in a single matmul;
    ``SiluAndMul`` combines them, and the row-parallel ``down_proj`` maps
    the result back to ``hidden_size``.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Options shared by both projections.
        common = dict(bias=bias, quant_config=quant_config)

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **common,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            prefix=f"{prefix}.down_proj",
            **common,
        )

        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        # The parallel linear layers return (output, bias); only the output
        # is used here.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention (multi-head or grouped-query) for one LLaMA layer.

    Query heads are divided evenly across tensor-parallel ranks. KV heads
    are partitioned when there are at least as many KV heads as ranks and
    replicated otherwise, so every rank holds at least one KV head.

    When the ``FA_PAD`` environment variable is ``"1"`` (read once at
    construction) and the layer is unquantized, the last
    ``_FA_PAD_COLS`` columns of the fused QKV projection output are
    dropped before splitting into q/k/v.
    """

    # Number of trailing QKV-output columns discarded when FA_PAD=1.
    # Presumably matches the pad_weight(..., 32) applied to the qkv_proj
    # weight in LlamaForCausalLM.load_weights -- confirm if either changes.
    _FA_PAD_COLS = 32

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # GGUF checkpoints use the non-NeoX rotary layout.
        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        # Read the FA_PAD switch once instead of on every forward call; this
        # matches LlamaForCausalLM, which also reads it once at construction.
        # NOTE(review): the weight padding in load_weights additionally
        # requires LLAMA_NN=1, while this trim only requires FA_PAD=1 --
        # confirm the two variables are always set together.
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        if self.use_fa_pad and self.quant_method is None:
            # Drop the padding columns carried by the padded qkv weight.
            qkv = qkv[..., :-self._FA_PAD_COLS]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA transformer block: attention followed by MLP.

    The residual stream is threaded explicitly: ``forward`` receives and
    returns ``residual`` alongside ``hidden_states``. The RMSNorm layers
    are called with ``(hidden_states, residual)`` and return a new pair,
    presumably fusing the residual add into the norm.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            original_max_pos = getattr(config,
                                       "original_max_position_embeddings",
                                       None)
            if original_max_pos:
                # Written into the config's own rope_scaling dict.
                rope_scaling["original_max_position_embeddings"] = (
                    config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # abacusai/Smaug-72B-v0.1 names the flag `attention_bias`;
        # internlm/internlm-7b names it `bias`.
        has_attention_bias = getattr(config, "attention_bias", False)
        if not has_attention_bias:
            has_attention_bias = getattr(config, "bias", False)

        kv_heads = getattr(config, "num_key_value_heads",
                           config.num_attention_heads)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=has_attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Attention sub-block. With no incoming residual (first layer), the
        # raw input itself becomes the residual.
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # MLP sub-block.
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
275
276
277
278


class LlamaModel(nn.Module):
    """Token embedding, decoder-layer stack and final norm for LLaMA.

    With pipeline parallelism each rank runs layers
    ``[start_layer, end_layer)``. The embedding and final norm are
    ``PPMissingLayer`` placeholders on ranks that do not need them.
    Non-last ranks return ``IntermediateTensors`` instead of hidden states.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        # LoRA may add extra vocabulary slots per adapter.
        extra_lora_tokens = 0
        if lora_config:
            extra_lora_tokens = (lora_config.lora_extra_vocab_size *
                                 (lora_config.max_loras or 1))
        self.vocab_size = config.vocab_size + extra_lora_tokens
        self.org_vocab_size = config.vocab_size

        pp_group = get_pp_group()
        # The last rank also needs the embedding when lm_head is tied to it.
        owns_embedding = pp_group.is_first_rank or (
            config.tie_word_embeddings and pp_group.is_last_rank)
        if owns_embedding:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # The factory's parameter must stay named `prefix`; make_layers
        # passes the per-layer prefix through it.
        def build_layer(prefix: str):
            return LlamaDecoderLayer(config=config,
                                     cache_config=cache_config,
                                     quant_config=quant_config,
                                     prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, build_layer, prefix=f"{prefix}.layers")

        self.norm = (RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                     if pp_group.is_last_rank else PPMissingLayer())

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if not get_pp_group().is_first_rank:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        # kv_caches is indexed relative to this rank's first layer.
        for offset in range(self.end_layer - self.start_layer):
            layer = self.layers[self.start_layer + offset]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[offset],
                attn_metadata,
                residual,
            )

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


359
class LlamaForCausalLM(nn.Module, SupportsLoRA):
    """LLaMA decoder stack plus LM head, with vLLM weight loading.

    After the regular checkpoint load, ``load_weights`` can re-layout weights
    for custom kernels. Environment switches read once in ``__init__``:

    * ``LLAMA_NN=1`` -- unquantized linear weights (and ``lm_head``) are
      passed through ``ops.trans_w16_gemm`` and reshaped.
    * ``GEMM_PAD=1`` / ``FA_PAD=1`` -- selected weights are padded by 32
      first (only on the ``LLAMA_NN`` path).
    * ``AWQ_PAD=1`` -- repacked AWQ tensors get extra zero columns for
      some shapes.

    AWQ and compressed-tensors checkpoints get their own repacking passes.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }
    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    # Row/column padding applied by pad_weight on the LLAMA_NN path.
    _NN_PAD_SIZE = 32
    # Extra int32 columns appended to zeros_and_scales when AWQ_PAD applies.
    _AWQ_PAD_GROUP = 2

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config,
                                prefix="model")
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        # Custom-kernel switches, read once at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors, then apply any kernel-specific re-layout.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)

            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn and self.quant_method is None:
            self._repack_llama_nn_weights(params_dict)

        if self.quant_method == "awq":
            self._repack_awq_weights(params_dict)

        if self.quant_method == "compressed_tensors":
            self._repack_compressed_tensors_weights(params_dict)

    def _repack_llama_nn_weights(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Pad and re-layout unquantized linear weights for LLAMA_NN kernels.

        For every matching parameter: optionally pad (GEMM_PAD / FA_PAD),
        run ``ops.trans_w16_gemm`` into a scratch tensor, copy it back, and
        reshape to ``(original_cols, -1)`` -- presumably a transpose for the
        custom GEMM; confirm against the kernel.
        """
        # "lm_head.weight" is always included. The previous code also
        # appended it behind `self.use_lm_nn`, an attribute that is never
        # set, so this path crashed with AttributeError; that append would
        # only have duplicated the entry and has been dropped.
        # (Patterns are used as regexes, exactly as before.)
        target_re = re.compile("|".join([
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
            "lm_head.weight",
        ]))
        qkv_re = re.compile("self_attn.qkv_proj.weight")
        pad = self._NN_PAD_SIZE

        for layername, weight in params_dict.items():
            # LM_NN is set per parameter and left at the last parameter's
            # value, exactly as before. NOTE(review): presumably read by
            # pad_weight / gemm_bank_conf / the custom ops -- confirm.
            os.environ['LM_NN'] = '1' if "lm_head.weight" in layername else '0'

            if not target_re.search(layername):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, pad)

            if self.use_fa_pad and qkv_re.search(layername):
                # Evaluated on the (possibly GEMM-padded) current shape.
                if not gemm_bank_conf(weight.data.shape[0]):
                    weight.data = pad_weight(weight.data, pad)

            scratch = torch.zeros_like(weight.data)
            ori_shape = scratch.shape
            ops.trans_w16_gemm(scratch, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(scratch)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _repack_awq_weights(self,
                            params_dict: Dict[str, nn.Parameter]) -> None:
        """Convert AWQ qweight/qzeros/scales into the custom s4 kernel layout.

        Writes the converted weights into ``qweight`` and the combined
        zeros/scales into ``zeros_and_scales``, then reshapes both to lead
        with ``dim_n``. With AWQ_PAD=1 and ``dim_k % 4096 == 0``, zero
        columns are appended to both.
        """
        target_re = re.compile("|".join([
            "self_attn.qkv_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        ]))

        for layername, qweight in params_dict.items():
            if not target_re.search(layername):
                continue

            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            group_size = self.quant_config.group_size
            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)

            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # Original author's layout notes:
            #   zeros_and_scales [k/group_size, n] -> [n, k/group_size]
            #   qweight          [k, n/8]          -> [n, k/8]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Allocate padding on the tensors' own device (previously
                # `.cuda()`, i.e. the current device) so torch.cat cannot hit
                # a device mismatch.
                zs_pad = torch.zeros(dim_n,
                                     self._AWQ_PAD_GROUP,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zs_pad), dim=1).contiguous()
                qw_pad = torch.zeros(dim_n,
                                     int(group_size // 4),
                                     dtype=torch.int32,
                                     device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qw_pad),
                                         dim=1).contiguous()

    def _repack_compressed_tensors_weights(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Store matching compressed-tensors weights in transposed order.

        Each matched weight keeps its ``(k, -1)`` shape, but its contents are
        replaced by the transposed data -- presumably the layout the custom
        kernel expects; confirm against the kernel.
        """
        target_re = re.compile("|".join([
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]))
        for layername, weight_data in params_dict.items():
            if not target_re.search(layername):
                continue
            k = weight_data.shape[0]
            transposed = weight_data.T.contiguous().reshape(k, -1)
            weight_data.data.copy_(transposed)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Install per-layer KV-cache scaling factors from a quantization file.

        Raises:
            RuntimeError: if a local layer's attention module has no
                ``_kv_scale`` attribute to receive the factor.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer = self.model.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # Placeholder (presumably a PPMissingLayer for a layer owned
                # by another pipeline stage): nothing to set here. Previously
                # this fell through with `layer_self_attn` unbound or stale.
                continue
            layer_self_attn = layer.self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            # Check the attribute that is actually written. The old check for
            # `kv_scale` on LlamaAttention could never succeed, so this
            # method always raised.
            if hasattr(layer_self_attn.attn, "_kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
            self, name: str,
            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
        """Translate a consolidated Mistral-format weight name to HF naming.

        ``wq``/``wk`` weights are also permuted to match the rotary layout
        used by this implementation.
        """

        def permute(w, n_heads):
            # Same head_dim fallback as LlamaAttention when the config does
            # not define head_dim explicitly.
            head_dim = getattr(
                self.config, "head_dim",
                self.config.hidden_size // self.config.num_attention_heads)
            attn_in = head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight