"vscode:/vscode.git/clone" did not exist on "efa6bed264b2dbb4c5d7a28e49fab60f6c69aef2"
Unverified Commit e97f802b authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: default avatarMicah Williamson <micah.williamson@amd.com>
parent 6e650f56
......@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
......@@ -57,10 +58,12 @@ class Attention(nn.Module):
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
is_attention_free = cache_config.is_attention_free
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
......@@ -70,8 +73,15 @@ class Attention(nn.Module):
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._k_scale = 1.0
self._v_scale = 1.0
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
self._k_scale_float = 1.0
self._v_scale_float = 1.0
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
......@@ -127,6 +137,9 @@ class Attention(nn.Module):
).parallel_config.pipeline_parallel_size)
]
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def forward(
self,
query: torch.Tensor,
......@@ -135,6 +148,9 @@ class Attention(nn.Module):
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
if self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
......@@ -161,6 +177,14 @@ class Attention(nn.Module):
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name)
def calc_kv_scales(self, key, value):
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
......
......@@ -52,8 +52,8 @@ class _PagedAttention:
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
ops.reshape_and_cache(
......@@ -80,8 +80,8 @@ class _PagedAttention:
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
tp_rank: int = 0
......@@ -149,8 +149,8 @@ class _IPEXPagedAttention(_PagedAttention):
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
ipex_modules.PagedAttention.reshape_and_cache(
......@@ -170,8 +170,8 @@ class _IPEXPagedAttention(_PagedAttention):
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
block_size = value_cache.shape[2]
......
......@@ -69,8 +69,8 @@ class PagedAttention:
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
ops.reshape_and_cache(
key,
......@@ -95,8 +95,8 @@ class PagedAttention:
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
......@@ -204,8 +204,8 @@ class PagedAttention:
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
k_scale: float,
v_scale: float,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
......
......@@ -133,7 +133,7 @@ if triton.__version__ >= "2.1.0":
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
......@@ -181,7 +181,7 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
......@@ -564,7 +564,7 @@ if triton.__version__ >= "2.1.0":
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
......@@ -604,7 +604,7 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
......@@ -713,8 +713,8 @@ if triton.__version__ >= "2.1.0":
b_seq_len,
b_ctx_len,
max_input_len,
k_scale: float = 1.0,
v_scale: float = 1.0,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
alibi_slopes=None,
sliding_window=None):
......
......@@ -120,11 +120,6 @@ class ModelConfig:
decoding draft models.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
......@@ -187,7 +182,6 @@ class ModelConfig:
factors.append(self.model)
factors.append(self.dtype)
factors.append(self.quantization)
factors.append(self.quantization_param_path)
factors.append(self.revision)
factors.append(self.code_revision)
factors.append(self.trust_remote_code)
......@@ -213,7 +207,6 @@ class ModelConfig:
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
......@@ -274,7 +267,6 @@ class ModelConfig:
else:
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
......@@ -1002,6 +994,7 @@ class CacheConfig:
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0,
calculate_kv_scales: Optional[bool] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
......@@ -1012,7 +1005,7 @@ class CacheConfig:
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb
self.calculate_kv_scales = calculate_kv_scales
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
......@@ -1021,6 +1014,10 @@ class CacheConfig:
self.num_gpu_blocks: Optional[int] = None
self.num_cpu_blocks: Optional[int] = None
# Set calculate_kv_scales to False if the value is unset.
if self.calculate_kv_scales is None:
self.calculate_kv_scales = False
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
......@@ -3297,7 +3294,6 @@ class VllmConfig:
f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, "
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
f"quantization_param_path={self.model_config.quantization_param_path},"
f" device_config={self.device_config.device}, "
f"decoding_config={self.decoding_config!r}, "
f"observability_config={self.observability_config!r}, "
......
......@@ -98,7 +98,6 @@ class EngineArgs:
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
......@@ -199,6 +198,8 @@ class EngineArgs:
generation_config: Optional[str] = None
enable_sleep_mode: bool = False
calculate_kv_scales: Optional[bool] = None
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
......@@ -350,17 +351,6 @@ class EngineArgs:
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
default=None,
help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when '
'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version '
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
parser.add_argument('--max-model-len',
type=int,
default=EngineArgs.max_model_len,
......@@ -962,6 +952,15 @@ class EngineArgs:
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")
parser.add_argument(
'--calculate-kv-scales',
action='store_true',
help='This enables dynamic calculation of '
'k_scale and v_scale when kv-cache-dtype is fp8. '
'If calculate-kv-scales is false, the scales will '
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
return parser
@classmethod
......@@ -991,7 +990,6 @@ class EngineArgs:
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager,
max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
......@@ -1068,6 +1066,7 @@ class EngineArgs:
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
)
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
......
......@@ -73,6 +73,8 @@ if TYPE_CHECKING:
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100
VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
......@@ -474,6 +476,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT":
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
# Divisor for dynamic value scale factor calculation for FP8 KV Cache
"V_SCALE_CONSTANT":
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
......
......@@ -3,6 +3,7 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -40,11 +41,16 @@ class BaseKVCacheMethod(QuantizeMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
# No need to process kv scales after loading if we are going to
# calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if current_platform.is_rocm():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
......@@ -58,6 +64,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if current_platform.is_rocm():
k_scale *= 2
v_scale *= 2
if not isinstance(k_scale, float) or not isinstance(
v_scale, float):
......@@ -65,9 +74,11 @@ class BaseKVCacheMethod(QuantizeMethodBase):
"for fp8 KV cache")
# These are used in the final Attention.forward()
layer._k_scale = k_scale
layer._v_scale = v_scale
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
layer._k_scale.copy_(k_scale)
layer._v_scale.copy_(v_scale)
layer._k_scale_float = k_scale
layer._v_scale_float = v_scale
if (k_scale == 1.0 and v_scale == 1.0
and "e5m2" not in layer.kv_cache_dtype):
logger.warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
......
......@@ -6,8 +6,7 @@ import json
import os
import tempfile
from collections import defaultdict
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional,
Tuple, Union)
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import filelock
import gguf
......@@ -23,7 +22,6 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.platforms import current_platform
from vllm.utils import PlaceholderModule
......@@ -496,47 +494,6 @@ def gguf_quant_weights_iterator(
yield name, param
def kv_cache_scales_loader(
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
Keep this function in sync with the output of
examples/other/fp8/extract_scales.py
"""
try:
with open(filename) as f:
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct,
context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
except FileNotFoundError:
logger.error("File or directory '%s' not found.", filename)
except json.JSONDecodeError:
logger.error("Error decoding JSON in file '%s'.", filename)
except Exception:
logger.exception("An error occurred while reading '%s'.", filename)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 for all "
"layers in TP rank %d as an error occurred during loading.", tp_rank)
return []
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor
......
......@@ -30,8 +30,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
......@@ -576,32 +574,3 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
......@@ -29,8 +29,7 @@ from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
......@@ -518,29 +516,3 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
......@@ -29,8 +29,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -43,9 +42,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
......@@ -440,32 +438,6 @@ class LlamaModel(nn.Module):
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
......@@ -593,9 +565,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(
......
......@@ -831,6 +831,7 @@ class MllamaTextCrossAttention(nn.Module):
) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32)
if self.attn.backend in (_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
......@@ -843,8 +844,8 @@ class MllamaTextCrossAttention(nn.Module):
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
1.0,
1.0,
i,
i,
)
elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
key_cache, value_cache = PagedAttention.split_kv_cache(
......@@ -853,7 +854,7 @@ class MllamaTextCrossAttention(nn.Module):
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
attn_metadata.cross_slot_mapping, "auto", i, i)
else:
raise ValueError(
f"Unsupported Attention backend {self.attn.backend} "
......
......@@ -30,8 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
......@@ -535,32 +533,3 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
......@@ -166,10 +166,6 @@ class FlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
......
......@@ -903,7 +903,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=
None # FIXME(kzawora): mutli-modality will not work here
None, # FIXME(kzawora): mutli-modality will not work here
enable_kv_scales_calculation=False,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
......@@ -1057,7 +1058,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None)
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
......
......@@ -3,7 +3,6 @@ import gc
import inspect
import itertools
import time
import warnings
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -41,7 +40,6 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
......@@ -1151,34 +1149,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
or current_platform.is_cuda()):
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s",
self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
"model %s does not support loading scaling factors.",
self.model.__class__)
else:
logger.warning(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
backend = self.vllm_config.compilation_config.init_backend(
......@@ -1366,6 +1336,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype=self.model_config.dtype,
device=self.device)
# Disable KV Scale Calculation for dummy data during profile run
if model_input.attn_metadata is not None:
model_input.attn_metadata.enable_kv_scales_calculation = False
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return
......@@ -1510,7 +1484,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
batch_size,
is_encoder_decoder_model=self.model_config.
is_encoder_decoder))
# Disable KV Scale Calculation for graph capture
attn_metadata.enable_kv_scales_calculation = False
if self.lora_config:
lora_mapping = LoRAMapping(
**dict(index_mapping=[0] * batch_size,
......
......@@ -282,6 +282,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
......
......@@ -190,6 +190,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=None,
context_lens=None,
effective_query_lens=None,
......@@ -208,6 +209,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
......@@ -239,6 +241,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
......@@ -425,6 +428,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=prompt_lens,
......@@ -496,6 +500,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables,
context_lens=context_lens,
)
......
......@@ -261,6 +261,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
is_prompt=True,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=seqlen_q,
max_seqlen=max_seqlen,
......@@ -345,6 +346,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=seq_lens,
seqlen_q=torch.tensor([]),
max_seqlen=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment