llama.py 22.3 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
23
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
29
import os
gaoqiong's avatar
gaoqiong committed
30
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
from vllm.attention import Attention, AttentionMetadata
33
from vllm.config import CacheConfig, LoRAConfig
34
35
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.activation import SiluAndMul
37
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
43
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
45
from vllm.model_executor.layers.sampler import Sampler
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
49
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import SamplerOutput
52
from vllm.utils import is_hip, print_warning_once
Woosuk Kwon's avatar
Woosuk Kwon committed
53

gaoqiong's avatar
gaoqiong committed
54
from vllm import _custom_ops as ops
55
56
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
57
58

class LlamaMLP(nn.Module):
    """Gated feed-forward block used inside each LLaMA decoder layer.

    A merged gate/up column-parallel projection feeds ``SiluAndMul``; the
    result is mapped back to ``hidden_size`` by a row-parallel down
    projection. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Gate and up projections share one matmul; each contributes
        # ``intermediate_size`` output features.
        merged_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=merged_widths,
            bias=bias,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        # The gating activation is hard-wired to SiLU-and-multiply.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Run gate/up projection, SiLU gating, then the down projection."""
        gate_and_up, _ = self.gate_up_proj(x)
        gated = self.act_fn(gate_and_up)
        projected, _ = self.down_proj(gated)
        return projected


class LlamaAttention(nn.Module):
    """Rotary-embedded self-attention for LLaMA, sharded over TP ranks.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least as many as ranks; otherwise they are
    replicated so each rank keeps at least one.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Width of this rank's un-padded fused QKV projection output.
        self.qkv_size = self.q_size + 2 * self.kv_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # LlamaForCausalLM.load_weights may pad the qkv_proj weight
        # (FA_PAD / GEMM_PAD, only when LLAMA_NN=1 and unquantized), which
        # makes the projection emit extra columns. Trim back to the expected
        # width rather than dropping 32 columns whenever FA_PAD=1: that broke
        # the split below whenever the weight had not actually been padded,
        # and re-read os.environ on every call. Assumes the padding is
        # appended at the tail, as the previous `[..., :-32]` slice did.
        if qkv.shape[-1] > self.qkv_size:
            qkv = qkv[..., :self.qkv_size]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA block: RMSNorm, self-attention, RMSNorm, MLP.

    The residual stream is threaded through the two-argument RMSNorm calls,
    which return the normalized activations together with the residual to
    carry forward.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # Optional config fields, with the fallbacks used when absent.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            original_max = getattr(config,
                                   "original_max_position_embeddings", None)
            if original_max:
                # NOTE: this writes into the config's own rope_scaling dict.
                rope_scaling["original_max_position_embeddings"] = original_max
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = (getattr(config, "attention_bias", False)
                          or getattr(config, "bias", False))
        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)

        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one block; returns ``(mlp_output, residual)``.

        ``residual`` is None for the first layer, in which case the raw
        input becomes the residual and is normalized on its own.
        """
        # Self Attention
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


class LlamaModel(nn.Module):
    """Token embedding, a stack of LlamaDecoderLayer blocks, final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        # With LoRA, the embedding table grows by lora_extra_vocab_size for
        # each of max_loras adapters (one when max_loras is falsy).
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.org_vocab_size = config.vocab_size
        self.vocab_size = config.vocab_size + extra_vocab

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList(
            LlamaDecoderLayer(config=config,
                              cache_config=cache_config,
                              quant_config=quant_config)
            for _ in range(config.num_hidden_layers))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run every decoder layer and return the final normalized states.

        ``inputs_embeds``, when given, is used instead of embedding
        ``input_ids``. ``kv_caches`` is indexed by layer position.
        """
        hidden_states = (inputs_embeds if inputs_embeds is not None else
                         self.get_input_embeddings(input_ids))
        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        normed, _ = self.norm(hidden_states, residual)
        return normed


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder stack with LM head, sampler and weight loading.

    Besides the standard checkpoint loading path, ``load_weights`` applies
    two optional post-load transforms added by this fork:

    * ``LLAMA_NN=1`` on an unquantized model: optional ``GEMM_PAD`` /
      ``FA_PAD`` padding, then a ``trans_w16_gemm`` layout conversion of the
      dense projection weights.
    * AWQ models: ``qweight``/``qzeros``/``scales`` are repacked through
      ``convert_s4``/``sz_permute`` into the ``zeros_and_scales`` parameter.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Parameter-name fragments rewritten by the LLAMA_NN transform.
    _LLAMA_NN_WEIGHT_KEYS = (
        "self_attn.qkv_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_up_proj.weight",
        "mlp.down_proj.weight",
    )
    # The one fragment among the above that FA_PAD additionally pads.
    _LLAMA_NN_QKV_KEY = "self_attn.qkv_proj.weight"
    # AWQ packed-weight fragments repacked after loading.
    _AWQ_QWEIGHT_KEYS = (
        "self_attn.qkv_proj.qweight",
        "self_attn.o_proj.qweight",
        "mlp.gate_up_proj.qweight",
        "mlp.down_proj.qweight",
    )

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

        # Always store quant_config (previously only set when not None) so
        # the attribute exists for every configuration. quant_method is the
        # config's get_name() result, or None for unquantized models.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Fork-specific switches, read once at construction time.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states of the decoder stack."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into the parameters.

        HF-style ``q/k/v_proj`` and ``gate/up_proj`` shards are routed into
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters. Afterwards the
        optional LLAMA_NN (unquantized only) and AWQ post-load transforms
        are applied.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        print_warning_once(
                            f"Found kv scale in the checkpoint (e.g. {name}), "
                            "but not found the expected name in the model "
                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
                            "not loaded.")
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn and self.quant_method is None:
            self._transform_weights_for_llama_nn(params_dict)

        if self.quant_method == "awq":
            self._repack_awq_weights(params_dict)

    def _transform_weights_for_llama_nn(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Optionally pad, then re-layout dense projection weights in place.

        Each selected weight of shape (rows, cols) is passed through
        ``ops.trans_w16_gemm`` and finally reshaped to ``(cols, -1)``.
        """
        for name, param in params_dict.items():
            if not any(key in name for key in self._LLAMA_NN_WEIGHT_KEYS):
                continue
            if self.use_gemm_pad and gemm_bank_conf(param.data.shape[0]):
                param.data = pad_weight(param.data, 32)

            # NOTE(review): gemm_bank_conf is re-evaluated on the possibly
            # already GEMM-padded shape. With GEMM_PAD=1 and FA_PAD=1 the
            # qkv weight can therefore be padded twice; with GEMM_PAD unset
            # and a bank conflict it is not padded at all despite FA_PAD=1.
            # Confirm the intended interaction with the qkv trimming done
            # in LlamaAttention.forward.
            if self.use_fa_pad and self._LLAMA_NN_QKV_KEY in name:
                if not gemm_bank_conf(param.data.shape[0]):
                    param.data = pad_weight(param.data, 32)

            transposed = torch.zeros_like(param.data)
            ori_shape = transposed.shape
            ops.trans_w16_gemm(transposed, param.data, ori_shape[0],
                               ori_shape[1])
            param.data.copy_(transposed)
            param.data = param.data.reshape(ori_shape[1], -1)

    def _repack_awq_weights(self,
                            params_dict: Dict[str, nn.Parameter]) -> None:
        """Repack AWQ qweight/qzeros/scales into the fork's kernel layout.

        Fills ``zeros_and_scales`` and rewrites ``qweight`` in place, then
        reshapes both to ``(dim_n, -1)``. When the input dimension is a
        multiple of 4096, both are right-padded with zero int32 columns.
        """
        group_size = self.quant_config.group_size
        # Extra zero columns appended to zeros_and_scales in the padded case.
        pad_group = 2
        for name, qweight in params_dict.items():
            if not any(key in name for key in self._AWQ_QWEIGHT_KEYS):
                continue
            qzeros = params_dict[name.replace("qweight", "qzeros")]
            scales = params_dict[name.replace("qweight", "scales")]
            zeros_and_scales = params_dict[name.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)
            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # [k/group_size, n] -> [n, k/group_size] and [k, n/8] -> [n, k/8]
            # (shape comments carried over from the original code).
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0:
                # Allocate padding directly on each parameter's device rather
                # than building on CPU and moving to the current CUDA device.
                zs_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zs_pad), dim=1).contiguous()
                qw_pad = torch.zeros(dim_n,
                                     int(group_size // 4),
                                     dtype=torch.int32,
                                     device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qw_pad),
                                         dim=1).contiguous()

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Set each layer's KV-cache scaling factor from a quantization file.

        Raises:
            RuntimeError: if a layer's attention object has no ``_kv_scale``.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            # Check the object actually written to. The previous check looked
            # for `kv_scale` on LlamaAttention itself, which never registers
            # such an attribute, so this method always raised.
            attn = layer_self_attn.attn
            if hasattr(attn, "_kv_scale"):
                attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")