# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
29
import os
gaoqiong's avatar
gaoqiong committed
30
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
from vllm.attention import Attention, AttentionMetadata
33
from vllm.config import CacheConfig, LoRAConfig
34
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
35
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.activation import SiluAndMul
37
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
43
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
45
from vllm.model_executor.layers.sampler import Sampler
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
49
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors, SamplerOutput
52
from vllm.utils import is_hip, print_warning_once
Woosuk Kwon's avatar
Woosuk Kwon committed
53

54
from .interfaces import SupportsLoRA
55
from .utils import is_pp_missing_parameter, make_layers
56

gaoqiong's avatar
gaoqiong committed
57
from vllm import _custom_ops as ops
Woosuk Kwon's avatar
Woosuk Kwon committed
58
59

class LlamaMLP(nn.Module):
    """Feed-forward block used inside each LLaMA decoder layer.

    The gate and up projections are held in one merged column-parallel
    linear layer (``gate_up_proj``) whose two output partitions each have
    ``intermediate_size`` features; ``down_proj`` is a row-parallel linear
    layer mapping ``intermediate_size`` back to ``hidden_size``.

    Args:
        hidden_size: Input and output feature size of the block.
        intermediate_size: Feature size of each of the two merged
            projections and input size of ``down_proj``.
        hidden_act: Name of the activation; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both
            linear layers.
        bias: Whether both linear layers carry a bias term.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Two equally sized output partitions: one for gate, one for up.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=merged_output_sizes,
            bias=bias,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run ``x`` through gate/up projection, activation and down projection.

        Both linear layers return a 2-tuple; only the first element is used.
        """
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention block of a LLaMA decoder layer.

    Holds a fused QKV projection, an output projection, a rotary embedding
    and the ``Attention`` backend module. Head counts are divided across
    the tensor-parallel world size obtained from
    ``get_tensor_model_parallel_world_size()``.

    Args:
        hidden_size: Model hidden size.
        num_heads: Total number of query heads (before TP partitioning).
        num_kv_heads: Total number of key/value heads (before TP
            partitioning).
        rope_theta: Base passed to ``get_rope``.
        rope_scaling: Optional rope scaling dict passed to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        quant_config: Optional quantization config forwarded to the linear
            layers and the attention backend.
        bias: Whether the QKV and output projections carry a bias.
        cache_config: Optional cache config forwarded to ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than ranks: KV heads are replicated across
            # tensor-parallel GPUs, so the rank count must be a multiple
            # of the KV head count.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: KV heads are partitioned
            # across tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embedding, attend, project out.

        Args:
            positions: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: KV cache tensor handed to the attention backend.
            attn_metadata: Metadata handed to the attention backend.

        Returns:
            The first element of the output projection's result.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into Q, K and V.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: attention and MLP, each preceded by RMSNorm.

    All hyper-parameters are read from ``config``; optional attributes fall
    back to defaults via ``getattr``.

    Args:
        config: Model configuration.
        cache_config: Optional cache config forwarded to the attention
            block.
        quant_config: Optional quantization config forwarded to the
            attention and MLP blocks.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            # Copy the original context length into the scaling dict when
            # the config provides a truthy value. Note that this writes
            # into the dict object obtained from ``config``.
            original_max = getattr(config,
                                   "original_max_position_embeddings", None)
            if original_max:
                rope_scaling["original_max_position_embeddings"] = (
                    config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # Without an explicit KV head count, use one KV head per query head.
        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)
        mlp_bias = getattr(config, "mlp_bias", False)

        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=mlp_bias,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Positions forwarded to the attention block.
            hidden_states: Layer input.
            kv_cache: KV cache tensor for this layer.
            attn_metadata: Metadata forwarded to the attention block.
            residual: Residual carried from the previous layer, or ``None``
                when there is none yet.

        Returns:
            The MLP output together with the residual produced by the
            post-attention norm.
        """
        # Self Attention
        if residual is not None:
            # With a residual the norm returns an updated
            # (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        mlp_output = self.mlp(hidden_states)
        return mlp_output, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
245
246
247
248


class LlamaModel(nn.Module):
    """LLaMA backbone: token embedding, decoder layers and a final RMSNorm.

    ``make_layers`` supplies ``start_layer``/``end_layer`` together with the
    layer container; ``forward`` only runs layers in that index range.

    Args:
        config: Model configuration.
        cache_config: Optional cache config forwarded to each decoder layer.
        quant_config: Optional quantization config forwarded to each
            decoder layer.
        lora_config: Optional LoRA config; when given, the embedding
            vocabulary is enlarged by
            ``lora_extra_vocab_size * (max_loras or 1)``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        # Extra vocabulary slots reserved for LoRA adapters (0 without LoRA).
        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        def _build_layer() -> LlamaDecoderLayer:
            # Zero-argument factory handed to make_layers.
            return LlamaDecoderLayer(config=config,
                                     cache_config=cache_config,
                                     quant_config=quant_config)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, _build_layer)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``input_ids`` in ``embed_tokens``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers in ``[start_layer, end_layer)``.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every layer.
            kv_caches: KV caches indexed relative to ``start_layer``.
            attn_metadata: Metadata forwarded to every layer.
            intermediate_tensors: On ranks other than the first pipeline
                rank, carries ``"hidden_states"`` and ``"residual"``; must
                not be ``None`` there.
            inputs_embeds: Optional precomputed embeddings that take
                precedence over ``input_ids`` on the first rank.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` when this is not the last pipeline rank; otherwise
            the hidden states after the final norm.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        # cache_slot is the layer index relative to start_layer.
        layer_indices = range(self.start_layer, self.end_layer)
        for cache_slot, layer_idx in enumerate(layer_indices):
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[cache_slot],
                attn_metadata,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        normed, _ = self.norm(hidden_states, residual)
        return normed


318
class LlamaForCausalLM(nn.Module, SupportsLoRA):
    """LLaMA backbone plus LM head, logits processor and sampler.

    Besides the forward/logits/sampling entry points, the class loads
    checkpoint weights (``load_weights``) and KV cache scaling factors
    (``load_kv_cache_scales``).

    Args:
        config: Model configuration.
        cache_config: Optional cache config forwarded to ``LlamaModel``.
        quant_config: Optional quantization config forwarded to
            ``LlamaModel`` and the LM head.
        lora_config: Optional LoRA config; enlarges the LM head vocabulary
            by ``lora_extra_vocab_size`` and selects the LoRA vocab padding
            size.
    """

    # Fused module name -> names of the checkpoint sub-modules it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            # Share the input embedding matrix with the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()
        # Opt-in switch (env var LLAMA_NN=1) for the weight re-layout that
        # load_weights applies after loading.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the ``LlamaModel`` backbone and return its output."""
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from ``hidden_states`` via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Return zero-filled ``hidden_states``/``residual`` placeholders.

        Both tensors have shape ``(batch_size, config.hidden_size)``.
        """
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj``/``gate_up_proj`` parameters with a shard id.
        Rotary-embedding buffers, unknown biases and parameters reported by
        ``is_pp_missing_parameter`` are skipped. When ``use_llama_nn`` is
        set, the linear weights are re-laid-out afterwards.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Reached when no stacked mapping loaded the tensor: handle
                # it as a regular (non-fused) parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        print_warning_once(
                            f"Found kv scale in the checkpoint (e.g. {name}), "
                            "but not found the expected name in the model "
                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
                            "not loaded.")
                        continue
                    else:
                        name = remapped_kv_scale_name

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn:
            self._apply_llama_nn_layout(params_dict)

    def _apply_llama_nn_layout(self, params_dict: Dict[str,
                                                       nn.Parameter]) -> None:
        """Re-lay-out attention/MLP linear weights via ``ops.trans_w16_gemm``.

        Every parameter whose name contains one of the four linear-weight
        suffixes is passed through the custom op, written back in place and
        then reshaped to ``(original_shape[1], -1)``.
        NOTE(review): the exact layout produced by ``trans_w16_gemm`` is not
        visible in this file -- confirm against the custom-op implementation.
        """
        lay_key_words = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        # Escape the names so "." only matches a literal dot, and compile
        # the alternation once instead of once per parameter.
        layer_pattern = re.compile("|".join(
            re.escape(key_word) for key_word in lay_key_words))

        for layername, weight in params_dict.items():
            if layer_pattern.search(layername) is None:
                continue
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(_weight)

            weight.data = weight.data.reshape(ori_shape[1], -1)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV cache scaling factors into the attention modules.

        Args:
            quantization_param_path: Path handed to
                ``kv_cache_scales_loader``.

        Raises:
            RuntimeError: If a layer's attention module exposes no KV cache
                scaling factor attribute.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer = self.model.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # An nn.Identity entry has no ``self_attn`` (presumably a
                # placeholder for a layer owned by another pipeline stage).
                # Skip it rather than fall through with an unbound or stale
                # attention module from a previous iteration.
                continue
            layer_self_attn = layer.self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            # The factor is stored on the inner attention backend, so accept
            # either the legacy wrapper attribute or the backend attribute
            # that is actually written below.
            attn_backend = getattr(layer_self_attn, "attn", None)
            if (hasattr(layer_self_attn, "kv_scale")
                    or hasattr(attn_backend, "_kv_scale")):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")