llama.py 25.9 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
23
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
29
import os
gaoqiong's avatar
gaoqiong committed
30
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
from vllm.attention import Attention, AttentionMetadata
33
from vllm.config import CacheConfig, LoRAConfig
34
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
35
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.activation import SiluAndMul
37
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
43
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
44
45
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
46
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
47
from vllm.model_executor.layers.sampler import Sampler
48
from vllm.model_executor.layers.vocab_parallel_embedding import (
49
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
50
from vllm.model_executor.model_loader.weight_utils import (
51
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
52
from vllm.model_executor.sampling_metadata import SamplingMetadata
53
from vllm.sequence import IntermediateTensors, SamplerOutput
54
from vllm.utils import is_hip
Woosuk Kwon's avatar
Woosuk Kwon committed
55

56
from .interfaces import SupportsLoRA
57
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
58

gaoqiong's avatar
gaoqiong committed
59
from vllm import _custom_ops as ops
60
61
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
62
63

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    Computes ``down_proj(act_fn(gate_up_proj(x)))`` where ``gate_up_proj`` is
    a single merged column-parallel projection producing two outputs of width
    ``intermediate_size`` and ``act_fn`` is ``SiluAndMul``.

    Args:
        hidden_size: Input width of ``gate_up_proj`` and output width of
            ``down_proj``.
        intermediate_size: Width of each of the two merged projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both linear
            layers.
        bias: Whether the linear layers carry a bias term.
        prefix: Parameter-name prefix forwarded to the linear layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        # Reject unsupported activations up front so that no projection
        # layer is constructed for a module that can never be used.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the projected result."""
        # Both linear layers return a pair; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention sub-layer of a LLaMA decoder layer.

    ``forward`` runs a fused QKV projection, applies the rotary embedding to
    Q and K, calls the ``Attention`` layer with the KV cache, and finishes
    with the output projection.

    Args:
        config: Model config; only an optional ``head_dim`` is read from it.
        hidden_size: Model hidden width.
        num_heads: Total number of query heads (before tensor parallelism).
        num_kv_heads: Total number of key/value heads.
        rope_theta: Base passed to ``get_rope``.
        rope_scaling: Optional scaling dict passed to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        quant_config: Optional quantization config for the linear and
            attention layers.
        bias: Whether the QKV and output projections carry a bias.
        cache_config: Optional cache config forwarded to ``Attention``.
        prefix: Parameter-name prefix for the projections.
    """

    # Number of trailing columns dropped from the fused QKV output when the
    # FA_PAD workaround is active.
    # NOTE(review): presumably mirrors the 32 passed to pad_weight() in
    # LlamaForCausalLM.load_weights -- confirm the two stay in sync.
    _FA_PAD_COLUMNS = 32

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Always define both attributes so readers never hit a missing
        # `quant_config` on unquantized models.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)
        # Resolve the FA_PAD switch once at construction instead of reading
        # the environment on every forward pass.
        # NOTE(review): the strip in forward() does not depend on LLAMA_NN,
        # whereas the padding in LlamaForCausalLM.load_weights does --
        # confirm both flags are always set together.
        self.strip_fa_pad = (os.environ.get('FA_PAD') == '1'
                             and self.quant_method is None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention for one layer and return the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        if self.strip_fa_pad:
            # Drop the trailing pad columns before splitting into Q/K/V.
            qkv = qkv[..., :-self._FA_PAD_COLUMNS]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: attention and MLP, each preceded by RMSNorm.

    The residual stream is threaded through the layer explicitly:
    ``forward`` takes the incoming residual and returns the updated one
    alongside the hidden states.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # Rotary-embedding settings, with defaults for configs lacking them.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            original_max_pos = getattr(
                config, "original_max_position_embeddings", None)
            if original_max_pos:
                # Written into the config's own dict (in place).
                rope_scaling["original_max_position_embeddings"] = (
                    original_max_pos)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # Fall back to multi-head attention when the config has no
        # dedicated KV-head count.
        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``."""
        # Self Attention
        if residual is None:
            # No residual passed in: the raw input becomes the residual and
            # the norm is applied on its own.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual, the norm returns the updated pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
269
270
271
272


class LlamaModel(nn.Module):
    """LLaMA backbone: token embedding, decoder layers, and final RMSNorm.

    Modules that the current pipeline-parallel rank does not need are
    replaced with ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        # Extra vocabulary slots reserved for LoRA adapters, if any.
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        pp_group = get_pp_group()
        # The embedding is built on the first rank, and also on the last
        # rank when the config ties word embeddings.
        needs_embedding = pp_group.is_first_rank or (
            config.tie_word_embeddings and pp_group.is_last_rank)
        if needs_embedding:
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def build_layer(prefix: str) -> LlamaDecoderLayer:
            # Factory handed to make_layers; the parameter keeps the name
            # `prefix` exactly as in the original lambda.
            return LlamaDecoderLayer(config=config,
                                     cache_config=cache_config,
                                     quant_config=quant_config,
                                     prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")

        if pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of ``embed_tokens`` for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        Returns the normalized hidden states on the last pipeline rank, and
        otherwise an ``IntermediateTensors`` holding ``hidden_states`` and
        ``residual``.
        """
        if get_pp_group().is_first_rank:
            # Caller-supplied embeddings take precedence over token ids.
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # kv_caches is indexed relative to this rank's first layer.
        for cache_idx, layer_idx in enumerate(
                range(self.start_layer, self.end_layer)):
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


352
class LlamaForCausalLM(nn.Module, SupportsLoRA):
    """LLaMA backbone plus LM head, logits processor, sampler and loaders.

    Post-load weight conversions are controlled by environment variables
    read once in ``__init__``: ``LLAMA_NN``, ``GEMM_PAD`` and ``FA_PAD``
    (each enabled by the value ``'1'``).
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config,
                                prefix="model")
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                # Share the embedding matrix with the LM head.
                self.lm_head.weight = self.model.embed_tokens.weight

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()

        # Always define both attributes so later code never hits a missing
        # `quant_config` on unquantized models.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Feature switches, snapshotted once at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the backbone and return its output unchanged."""
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return what the logits processor produces for ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's output for ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Return zero-filled ``hidden_states``/``residual`` tensors of
        shape ``(batch_size, config.hidden_size)``."""
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. After all tensors
        are loaded, the optional LLAMA_NN and AWQ conversions are applied.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Reached when no stacked mapping consumed the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_dense_weights(params_dict)

        if self.quant_method == "awq":
            self._repack_awq_weights(params_dict)

    def _relayout_dense_weights(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Pad (optionally) and re-lay-out unquantized projection weights."""
        dense_keys = (
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
            "lm_head.weight",
        )
        qkv_key = "self_attn.qkv_proj.weight"

        for layername, weight in params_dict.items():
            if not any(key in layername for key in dense_keys):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)

            # FA padding applies to the fused QKV weight only, and only if
            # gemm_bank_conf() is false for its (possibly padded) first dim.
            if (self.use_fa_pad and qkv_key in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, 32)

            # NOTE(review): trans_w16_gemm presumably writes a transposed
            # layout into the scratch tensor -- confirm against the op.
            scratch = torch.zeros_like(weight.data)
            padded_shape = scratch.shape
            ops.trans_w16_gemm(scratch, weight.data, padded_shape[0],
                               padded_shape[1])
            weight.data.copy_(scratch)
            weight.data = weight.data.reshape(padded_shape[1], -1)

    def _repack_awq_weights(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Convert AWQ ``qweight``/``qzeros``/``scales`` via the custom ops."""
        awq_keys = (
            "self_attn.qkv_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        )
        pad_group = 2

        for layername, qweight in params_dict.items():
            if not any(key in layername for key in awq_keys):
                continue

            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            group_size = self.quant_config.group_size

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)

            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # Reshape: [k/group_size, n] -> [n, k/group_size]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            # Reshape: [k, n/8] -> [n, k/8]
            qweight.data = qweight.data.reshape(dim_n, -1)

            # NOTE(review): the reason for the 4096 threshold is not visible
            # here -- confirm with the kernel author.
            if dim_k % 4096 == 0:
                # Create the padding on the same device as the tensor it is
                # concatenated with, rather than the current CUDA device.
                sz_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, sz_pad), dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          int(group_size // 4),
                                          dtype=torch.int32,
                                          device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV cache scaling factors from a file.

        Args:
            quantization_param_path: Path handed to
                ``kv_cache_scales_loader``.

        Raises:
            RuntimeError: If a layer's self-attention module has no
                ``kv_scale`` attribute.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer = self.model.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # Placeholder layer (presumably owned by another pipeline
                # stage): it has no attention module to configure. Skipping
                # avoids using an unbound or stale `layer_self_attn`.
                continue
            layer_self_attn = layer.self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            # NOTE(review): this checks `kv_scale` on the attention wrapper
            # but writes `attn._kv_scale`; LlamaAttention in this file never
            # sets `kv_scale`, so verify this check is still the intended one.
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")