# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import os
import re
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import gemm_bank_conf, pad_weight
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip, print_warning_once


class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    ``gate_proj`` and ``up_proj`` are fused into one column-parallel
    projection (``gate_up_proj``). Its output is passed through
    ``SiluAndMul`` and then projected back to ``hidden_size`` by the
    row-parallel ``down_proj``.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to the linears.
        bias: Whether the linear layers carry a bias.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # One fused matmul whose output holds the gate and up projections
        # side by side (two shards of size intermediate_size).
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # NOTE(review): presumably computes silu(gate) * up over the two
        # halves of gate_up_proj's output — confirm against SiluAndMul.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, gated activation and down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention with rotary position embeddings, sharded over
    tensor-parallel ranks.

    Query heads are split evenly across ranks. KV heads are either split
    (when there are at least as many as ranks) or replicated (when there
    are fewer), which supports grouped-query attention.

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads (across all ranks).
        num_kv_heads: Total number of key/value heads (across all ranks).
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config passed to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        quant_config: Optional quantization config for the projections.
        bias: Whether the qkv and output projections carry a bias.
        cache_config: Optional KV-cache config forwarded to ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)
        # With FA_PAD=1, the model's load_weights (LLAMA_NN path) may append
        # padding to the unquantized fused qkv weight. The extra output
        # columns are dropped in forward(). The environment is read once
        # here, the same way LlamaForCausalLM reads it.
        self.strip_qkv_pad = (os.environ.get('FA_PAD') == '1'
                              and self.quant_method is None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE, attend, and project back.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        if self.strip_qkv_pad:
            # Keep exactly the real q/k/v columns. This equals the old
            # ``[..., :-32]`` when 32 columns were padded, and is a no-op
            # (instead of breaking the split below) when no padding was
            # applied.
            qkv = qkv[..., :self.q_size + 2 * self.kv_size]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA transformer block: RMSNorm -> self-attention, then
    RMSNorm -> gated MLP, carrying a separate residual stream.

    Args:
        config: HuggingFace ``LlamaConfig`` (optional fields are read with
            ``getattr`` defaults).
        cache_config: Optional KV-cache config for the attention layer.
        quant_config: Optional quantization config for all linear layers.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            # Add the key to a copy so the shared config dict is not mutated
            # in place (this constructor runs once per layer).
            rope_scaling = dict(rope_scaling)
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block.

        Args:
            residual: ``None`` for the first layer. Otherwise it is the
                residual returned by the previous layer.

        Returns:
            ``(hidden_states, residual)`` to feed into the next layer. The
            residual is not yet added to ``hidden_states``.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # RMSNorm called with a residual returns a (normed, residual)
            # pair. NOTE(review): presumably adds the residual before
            # normalizing — confirm against RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    """LLaMA backbone: token embedding, ``num_hidden_layers`` decoder
    layers and a final RMSNorm. Produces hidden states, not logits.

    Args:
        config: HuggingFace ``LlamaConfig``.
        cache_config: Optional KV-cache config forwarded to every layer.
        quant_config: Optional quantization config forwarded to every layer.
        lora_config: Optional LoRA config; it enlarges the embedding vocab.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra embedding rows for LoRA-added tokens:
        # lora_extra_vocab_size * max_loras (max_loras treated as 1 if unset).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config=config,
                              cache_config=cache_config,
                              quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run all decoder layers and the final norm.

        If ``inputs_embeds`` is given, it is used directly and ``input_ids``
        is ignored. ``kv_caches`` is indexed by layer position, so it must
        have at least one entry per layer.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        # The final norm also folds in the pending residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA causal LM: ``LlamaModel`` backbone plus LM head, logits
    processor and sampler, with optional DCU-kernel weight repacking.

    The environment switches are read once at construction:
    ``LLAMA_NN=1`` re-lays out unquantized linear weights after loading,
    and ``GEMM_PAD`` / ``FA_PAD`` control extra padding on that path.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Custom-kernel switches, read once from the environment.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (logits are computed
        separately by ``compute_logits``)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace-format checkpoint tensors into this model.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters through their
        ``weight_loader`` with a shard id. Rotary caches stored in some
        checkpoints are skipped, and FP8 ``kv_scale`` names are remapped
        onto ``.attn.kv_scale``. Loaded weights are then post-processed for
        the custom kernels: the LLAMA_NN re-layout for unquantized models,
        and AWQ repacking when ``quant_method == "awq"``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        print_warning_once(
                            f"Found kv scale in the checkpoint (e.g. {name}), "
                            "but not found the expected name in the model "
                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
                            "not loaded.")
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn and self.quant_method is None:
            self._repack_unquantized_weights(params_dict)
        if self.quant_method == "awq":
            self._repack_awq_weights(params_dict)

    def _repack_unquantized_weights(
            self, params_dict: Dict[str, torch.Tensor]) -> None:
        """Re-layout unquantized linear weights for the LLAMA_NN kernels.

        For each qkv/o/gate_up/down weight, optionally pad it with
        ``pad_weight(..., 32)`` (GEMM_PAD / FA_PAD). Then run it through
        ``ops.trans_w16_gemm`` and reshape it to ``[shape[1], -1]`` of the
        (padded) shape. Tensors are swapped via ``.data`` so the
        registered parameter objects stay the same.
        """
        linear_keys = ("self_attn.qkv_proj.weight",
                       "self_attn.o_proj.weight",
                       "mlp.gate_up_proj.weight",
                       "mlp.down_proj.weight")
        qkv_key = "self_attn.qkv_proj.weight"
        for layername, weight in params_dict.items():
            if not any(key in layername for key in linear_keys):
                continue
            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)
            # FA_PAD: pad the fused qkv weight when its (possibly already
            # GEMM-padded) row count fails gemm_bank_conf. When FA_PAD=1,
            # LlamaAttention.forward keeps only the real q/k/v columns.
            if (self.use_fa_pad and qkv_key in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, 32)

            repacked = torch.zeros_like(weight.data)
            ori_shape = repacked.shape
            ops.trans_w16_gemm(repacked, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(repacked)
            # NOTE(review): the reshape to [shape[1], -1] suggests the kernel
            # writes a transposed layout — confirm against trans_w16_gemm.
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _repack_awq_weights(self, params_dict: Dict[str,
                                                   torch.Tensor]) -> None:
        """Convert AWQ parameters into the layout used by the custom kernel.

        For each linear ``qweight``, ``ops.convert_s4`` produces a repacked
        weight and combined zeros/scales. The latter is permuted by
        ``ops.sz_permute``. Both results are written back in place and
        reshaped to ``[dim_n, -1]``. When ``dim_k`` is a multiple of 4096,
        both are right-padded with zero int32 columns.
        """
        qweight_keys = ("self_attn.qkv_proj.qweight",
                        "self_attn.o_proj.qweight",
                        "mlp.gate_up_proj.qweight",
                        "mlp.down_proj.qweight")
        group_size = self.quant_config.group_size
        pad_group = 2
        for layername in params_dict:
            if not any(key in layername for key in qweight_keys):
                continue
            qweight = params_dict[layername]
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)
            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # Original author's shape notes for these reshapes:
            #   zeros_and_scales: [k/group_size, n] -> [n, k/group_size]
            #   qweight:          [k, n/8]          -> [n, k/8]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0:
                # Allocate the zero padding on the parameters' own device
                # rather than on the current CUDA device.
                zs_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zs_pad), dim=1).contiguous()
                qw_pad = torch.zeros(dim_n,
                                     int(group_size // 4),
                                     dtype=torch.int32,
                                     device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qw_pad),
                                         dim=1).contiguous()

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Set each layer's KV-cache scaling factor from a quantization file.

        Raises:
            RuntimeError: If a layer's attention module has no ``_kv_scale``
                attribute to set.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            # LlamaAttention itself never defines a kv scale. The scale is
            # stored on its inner Attention module, so check and set it
            # there. The old check on self_attn always failed.
            inner_attn = self.model.layers[layer_idx].self_attn.attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(inner_attn, "_kv_scale"):
                inner_attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")