# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
23
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.config import CacheConfig, LoRAConfig
32
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
33
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.activation import SiluAndMul
35
from vllm.model_executor.layers.layernorm import RMSNorm
36
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
41
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
42
43
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
45
from vllm.model_executor.layers.sampler import Sampler
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
from vllm.model_executor.model_loader.weight_utils import (
49
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors, SamplerOutput
52
from vllm.utils import is_hip
Woosuk Kwon's avatar
Woosuk Kwon committed
53

54
from .interfaces import SupportsLoRA
55
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
56

Woosuk Kwon's avatar
Woosuk Kwon committed
57
58

class LlamaMLP(nn.Module):
    """Feed-forward block of a LLaMA decoder layer.

    The input goes through a fused gate/up projection, the ``SiluAndMul``
    activation, and finally a down projection back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input width of the fused projection and output
                width of the down projection.
            intermediate_size: Width of each of the two fused outputs and
                input width of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded unchanged to both linear layers.
            bias: Forwarded unchanged to both linear layers.
            prefix: Dotted name prefix used to name the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()

        # The gate and up projections are merged into a single layer that
        # produces two outputs of ``intermediate_size`` each.
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Both linear layers return a pair; the second element is not used.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention block of a LLaMA decoder layer.

    Consists of a fused QKV projection, a rotary embedding applied to the
    query and key, the ``Attention`` layer (which receives the KV cache)
    and an output projection.

    ``total_num_heads`` / ``total_num_kv_heads`` hold the model-wide head
    counts, while ``num_heads`` / ``num_kv_heads`` hold the per-rank counts
    after dividing by the tensor-parallel world size.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            config: Model config; only its optional ``head_dim`` attribute
                is read here.
            hidden_size: Model hidden width.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of key/value heads (before TP split).
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` unchanged.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            bias: Forwarded to both linear layers.
            cache_config: Forwarded to ``Attention``.
            prefix: Dotted name prefix used to name the sub-layers.

        Raises:
            AssertionError: If the head counts are not compatible with the
                tensor-parallel world size (see the checks below).
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo.
        # "Optional" means the attribute may be present but set to None, in
        # which case a plain getattr default would not kick in and the size
        # computations below would fail on ``int * None``. Treat None the
        # same as a missing attribute.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim

        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection.

        Args:
            positions: Passed to the rotary embedding together with Q and K.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Passed through to the ``Attention`` layer.
            attn_metadata: Passed through to the ``Attention`` layer.

        Returns:
            The first element of the output projection's result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is laid out as [Q | K | V] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: attention and MLP, each preceded by an
    ``RMSNorm`` that also carries the residual stream."""

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and norm sub-layers from ``config``.

        Args:
            config: Model config; optional attributes are read with
                ``getattr`` and fall back to the defaults used below.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Dotted name prefix used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # When the config carries an original context length, copy it into
        # the rope scaling dict (this writes into the dict obtained from
        # the config).
        original_max_pos = getattr(config, "original_max_position_embeddings",
                                   None)
        if rope_scaling is not None and original_max_pos:
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False)
        if not attention_bias:
            attention_bias = getattr(config, "bias", False)

        # Configs without a separate KV head count use the query head count.
        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Layer input.
            kv_cache: Forwarded to the attention block.
            attn_metadata: Forwarded to the attention block.
            residual: Residual from the previous layer, or ``None`` when
                there is none yet.

        Returns:
            The MLP output and the residual produced by the second norm.
        """
        # Self Attention
        if residual is not None:
            # Called with a residual, the norm returns an updated
            # (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the un-normalised input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
257
258
259
260


class LlamaModel(nn.Module):
    """Token embedding, stack of decoder layers and final norm.

    Sub-modules that are not needed on the current pipeline-parallel rank
    are replaced by ``PPMissingLayer`` placeholders.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, the decoder layers and the final norm.

        Args:
            config: Model config.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
            lora_config: When given, enlarges the embedding vocabulary by
                ``lora_extra_vocab_size * (max_loras or 1)``.
            prefix: Dotted name prefix used to name the decoder layers.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        if lora_config:
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1))
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        pp_group = get_pp_group()

        # The embedding lives on the first rank, and additionally on the
        # last rank when word embeddings are tied.
        if pp_group.is_first_rank or (config.tie_word_embeddings
                                      and pp_group.is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def build_layer(prefix: str) -> LlamaDecoderLayer:
            # ``prefix`` here is the per-layer prefix supplied by
            # ``make_layers``; it intentionally shadows the outer one.
            return LlamaDecoderLayer(config=config,
                                     cache_config=cache_config,
                                     quant_config=quant_config,
                                     prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        if pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of applying ``embed_tokens`` to ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Forwarded to every decoder layer.
            kv_caches: One cache per local layer, indexed from 0 for
                ``start_layer``.
            attn_metadata: Forwarded to every decoder layer.
            intermediate_tensors: ``hidden_states`` / ``residual`` handed
                over from the previous rank; required on non-first ranks.
            inputs_embeds: Optional replacement for the embedding lookup on
                the first rank.

        Returns:
            The normed hidden states on the last rank, otherwise an
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # ``cache_slot`` counts local layers from 0; ``layer_idx`` is the
        # index into ``self.layers``.
        for cache_slot, layer_idx in enumerate(
                range(self.start_layer, self.end_layer)):
            decoder_layer = self.layers[layer_idx]
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[cache_slot],
                attn_metadata,
                residual,
            )

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


340
class LlamaForCausalLM(nn.Module, SupportsLoRA):
    """LLaMA model with a language-modelling head."""

    # Fused module name -> names of the separate modules it packs together.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }
371

372
373
374
    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the backbone and, on the last pipeline rank, the LM head,
        logits processor and sampler.

        Args:
            config: Model config.
            cache_config: Forwarded to the backbone.
            quant_config: Forwarded to the backbone and the LM head.
            lora_config: When given, enlarges the LM head vocabulary by
                ``lora_extra_vocab_size`` and selects the LoRA padding size.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config,
                                prefix="model")

        if not get_pp_group().is_last_rank:
            # Only the last rank owns an LM head; other ranks get a
            # placeholder and nothing else is set up.
            self.lm_head = PPMissingLayer()
            return

        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.unpadded_vocab_size = unpadded_vocab_size

        # We need bigger padding if using lora for kernel compatibility
        if lora_config:
            padding_size = lora_config.lora_vocab_padding_size
        else:
            padding_size = DEFAULT_VOCAB_PADDING_SIZE

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            # Share the input embedding weight with the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()
Woosuk Kwon's avatar
Woosuk Kwon committed
413
414
415

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the backbone and return whatever it returns.

        All arguments are passed positionally, in order, to ``self.model``.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors)
425

426
427
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the logits processor's output for ``hidden_states``.

        The LM head, the hidden states and the sampling metadata are handed
        to ``self.logits_processor`` in that order.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

432
433
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the result of running ``self.sampler`` on ``logits``."""
        return self.sampler(logits, sampling_metadata)

440
441
442
443
444
445
446
447
448
449
450
451
452
453
    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Return zero-filled ``hidden_states`` and ``residual`` tensors.

        Each tensor has shape ``(batch_size, config.hidden_size)`` and the
        requested ``dtype`` and ``device``; the two are separate tensors.
        """
        shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors({
            key: torch.zeros(shape, dtype=dtype, device=device)
            for key in ("hidden_states", "residual")
        })

454
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Checkpoint entries for the separate q/k/v and gate/up projections
        are routed into the fused ``qkv_proj`` / ``gate_up_proj``
        parameters. Rotary embedding buffers, extra biases that have no
        matching parameter, and parameters missing on this pipeline rank
        are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        # Checkpoint entries that are never loaded. Models trained using
        # ColossalAI may include the cos/sin caches in the checkpoint.
        ignored_markers = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if any(marker in name for marker in ignored_markers):
                continue

            scale_name = get_compressed_tensors_cache_scale(name)
            if scale_name:
                # Loading kv cache scales for compressed-tensors quantization
                target = params_dict[scale_name]
                load_fn = getattr(target, "weight_loader",
                                  default_weight_loader)
                load_fn(target, loaded_weight[0])
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # NOTE: ``name`` stays renamed even if one of the skips
                # below is taken; the ``else`` branch then sees the fused
                # name and applies the same skip checks to it.
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                target = params_dict[name]
                target.weight_loader(target, loaded_weight, shard_id)
                break
            else:
                # Reached when no fused parameter was loaded above.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None or is_pp_missing_parameter(name, self):
                    continue

                target = params_dict[name]
                load_fn = getattr(target, "weight_loader",
                                  default_weight_loader)
                load_fn(target, loaded_weight)
512
513
514
515
516
517
518
519
520
521
522

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV cache scaling factors into the attention layers.

        Args:
            quantization_param_path: Path handed to
                ``kv_cache_scales_loader`` together with the tensor-parallel
                rank/size, the layer count and the config's model type.

        Raises:
            RuntimeError: If a layer's attention block has nowhere to store
                the scaling factor.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer = self.model.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # Placeholder layer (presumably one that lives on another
                # pipeline rank): it has no attention block to configure.
                # Skipping it avoids reusing the previous iteration's
                # attention block, or an unbound variable on the first
                # iteration.
                continue
            layer_self_attn = layer.self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2

            # The factor is stored on the inner attention layer, so that is
            # where the slot must exist. The legacy ``kv_scale`` marker on
            # the outer block is still accepted for backward compatibility.
            inner_attn = getattr(layer_self_attn, "attn", None)
            if (hasattr(inner_attn, "_kv_scale")
                    or hasattr(layer_self_attn, "kv_scale")):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")