Unverified Commit e97f802b authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: default avatarMicah Williamson <micah.williamson@amd.com>
parent 6e650f56
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
...@@ -57,10 +58,12 @@ class Attention(nn.Module): ...@@ -57,10 +58,12 @@ class Attention(nn.Module):
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size block_size = cache_config.block_size
is_attention_free = cache_config.is_attention_free is_attention_free = cache_config.is_attention_free
calculate_kv_scales = cache_config.calculate_kv_scales
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16 block_size = 16
is_attention_free = False is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = num_heads num_kv_heads = num_heads
...@@ -70,8 +73,15 @@ class Attention(nn.Module): ...@@ -70,8 +73,15 @@ class Attention(nn.Module):
# expect the pre-quantized k/v_scale to be loaded along # expect the pre-quantized k/v_scale to be loaded along
# with the model weights. # with the model weights.
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self._k_scale = 1.0 self.calculate_kv_scales = calculate_kv_scales
self._v_scale = 1.0 self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
self._k_scale_float = 1.0
self._v_scale_float = 1.0
quant_method = quant_config.get_quant_method( quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None self, prefix=prefix) if quant_config else None
if quant_method is not None: if quant_method is not None:
...@@ -127,6 +137,9 @@ class Attention(nn.Module): ...@@ -127,6 +137,9 @@ class Attention(nn.Module):
).parallel_config.pipeline_parallel_size) ).parallel_config.pipeline_parallel_size)
] ]
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
...@@ -135,6 +148,9 @@ class Attention(nn.Module): ...@@ -135,6 +148,9 @@ class Attention(nn.Module):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
if self.use_output: if self.use_output:
output = torch.empty_like(query) output = torch.empty_like(query)
hidden_size = query.size(-1) hidden_size = query.size(-1)
...@@ -161,6 +177,14 @@ class Attention(nn.Module): ...@@ -161,6 +177,14 @@ class Attention(nn.Module):
return torch.ops.vllm.unified_attention( return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name) query, key, value, self.layer_name)
def calc_kv_scales(self, key, value):
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore s += f", num_heads={self.impl.num_heads}" # type: ignore
......
...@@ -52,8 +52,8 @@ class _PagedAttention: ...@@ -52,8 +52,8 @@ class _PagedAttention:
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
*args, *args,
) -> None: ) -> None:
ops.reshape_and_cache( ops.reshape_and_cache(
...@@ -80,8 +80,8 @@ class _PagedAttention: ...@@ -80,8 +80,8 @@ class _PagedAttention:
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
*args, *args,
) -> None: ) -> None:
tp_rank: int = 0 tp_rank: int = 0
...@@ -149,8 +149,8 @@ class _IPEXPagedAttention(_PagedAttention): ...@@ -149,8 +149,8 @@ class _IPEXPagedAttention(_PagedAttention):
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
*args, *args,
) -> None: ) -> None:
ipex_modules.PagedAttention.reshape_and_cache( ipex_modules.PagedAttention.reshape_and_cache(
...@@ -170,8 +170,8 @@ class _IPEXPagedAttention(_PagedAttention): ...@@ -170,8 +170,8 @@ class _IPEXPagedAttention(_PagedAttention):
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
*args, *args,
) -> None: ) -> None:
block_size = value_cache.shape[2] block_size = value_cache.shape[2]
......
...@@ -69,8 +69,8 @@ class PagedAttention: ...@@ -69,8 +69,8 @@ class PagedAttention:
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
) -> None: ) -> None:
ops.reshape_and_cache( ops.reshape_and_cache(
key, key,
...@@ -95,8 +95,8 @@ class PagedAttention: ...@@ -95,8 +95,8 @@ class PagedAttention:
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -204,8 +204,8 @@ class PagedAttention: ...@@ -204,8 +204,8 @@ class PagedAttention:
max_query_len: int, max_query_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int], sliding_window: Optional[int],
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty_like(query) output = torch.empty_like(query)
context_attention_fwd( context_attention_fwd(
......
...@@ -133,7 +133,7 @@ if triton.__version__ >= "2.1.0": ...@@ -133,7 +133,7 @@ if triton.__version__ >= "2.1.0":
other=0.0) # [D,N] other=0.0) # [D,N]
if k_load.dtype.is_fp8(): if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype) k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else: else:
k = k_load k = k_load
...@@ -181,7 +181,7 @@ if triton.__version__ >= "2.1.0": ...@@ -181,7 +181,7 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n[:, None]) < cur_batch_ctx_len), ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D] other=0.0) # [N,D]
if v_load.dtype.is_fp8(): if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype) v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else: else:
v = v_load v = v_load
p = p.to(v.dtype) p = p.to(v.dtype)
...@@ -564,7 +564,7 @@ if triton.__version__ >= "2.1.0": ...@@ -564,7 +564,7 @@ if triton.__version__ >= "2.1.0":
other=0.0) # [D,N] other=0.0) # [D,N]
if k_load.dtype.is_fp8(): if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype) k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else: else:
k = k_load k = k_load
...@@ -604,7 +604,7 @@ if triton.__version__ >= "2.1.0": ...@@ -604,7 +604,7 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n[:, None]) < cur_batch_ctx_len), ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) other=0.0)
if v_load.dtype.is_fp8(): if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype) v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype)
else: else:
v = v_load v = v_load
p = p.to(v.dtype) p = p.to(v.dtype)
...@@ -713,8 +713,8 @@ if triton.__version__ >= "2.1.0": ...@@ -713,8 +713,8 @@ if triton.__version__ >= "2.1.0":
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale: float = 1.0, k_scale: torch.Tensor,
v_scale: float = 1.0, v_scale: torch.Tensor,
alibi_slopes=None, alibi_slopes=None,
sliding_window=None): sliding_window=None):
......
...@@ -120,11 +120,6 @@ class ModelConfig: ...@@ -120,11 +120,6 @@ class ModelConfig:
decoding draft models. decoding draft models.
quantization: Quantization method that was used to quantize the model quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized. weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode. disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid. If False, we will use CUDA graph and eager execution in hybrid.
...@@ -187,7 +182,6 @@ class ModelConfig: ...@@ -187,7 +182,6 @@ class ModelConfig:
factors.append(self.model) factors.append(self.model)
factors.append(self.dtype) factors.append(self.dtype)
factors.append(self.quantization) factors.append(self.quantization)
factors.append(self.quantization_param_path)
factors.append(self.revision) factors.append(self.revision)
factors.append(self.code_revision) factors.append(self.code_revision)
factors.append(self.trust_remote_code) factors.append(self.trust_remote_code)
...@@ -213,7 +207,6 @@ class ModelConfig: ...@@ -213,7 +207,6 @@ class ModelConfig:
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None, spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None, quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None, enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20, max_logprobs: int = 20,
...@@ -274,7 +267,6 @@ class ModelConfig: ...@@ -274,7 +267,6 @@ class ModelConfig:
else: else:
self.tokenizer_revision = tokenizer_revision self.tokenizer_revision = tokenizer_revision
self.quantization = quantization self.quantization = quantization
self.quantization_param_path = quantization_param_path
self.enforce_eager = enforce_eager self.enforce_eager = enforce_eager
self.max_seq_len_to_capture = max_seq_len_to_capture self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs self.max_logprobs = max_logprobs
...@@ -1002,6 +994,7 @@ class CacheConfig: ...@@ -1002,6 +994,7 @@ class CacheConfig:
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False, enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0, cpu_offload_gb: float = 0,
calculate_kv_scales: Optional[bool] = None,
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization self.gpu_memory_utilization = gpu_memory_utilization
...@@ -1012,7 +1005,7 @@ class CacheConfig: ...@@ -1012,7 +1005,7 @@ class CacheConfig:
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb self.cpu_offload_gb = cpu_offload_gb
self.calculate_kv_scales = calculate_kv_scales
self._verify_args() self._verify_args()
self._verify_cache_dtype() self._verify_cache_dtype()
self._verify_prefix_caching() self._verify_prefix_caching()
...@@ -1021,6 +1014,10 @@ class CacheConfig: ...@@ -1021,6 +1014,10 @@ class CacheConfig:
self.num_gpu_blocks: Optional[int] = None self.num_gpu_blocks: Optional[int] = None
self.num_cpu_blocks: Optional[int] = None self.num_cpu_blocks: Optional[int] = None
# Set calculate_kv_scales to False if the value is unset.
if self.calculate_kv_scales is None:
self.calculate_kv_scales = False
def metrics_info(self): def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus # convert cache_config to dict(key: str, value: str) for prometheus
# metrics info # metrics info
...@@ -3297,7 +3294,6 @@ class VllmConfig: ...@@ -3297,7 +3294,6 @@ class VllmConfig:
f"quantization={self.model_config.quantization}, " f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, " f"enforce_eager={self.model_config.enforce_eager}, "
f"kv_cache_dtype={self.cache_config.cache_dtype}, " f"kv_cache_dtype={self.cache_config.cache_dtype}, "
f"quantization_param_path={self.model_config.quantization_param_path},"
f" device_config={self.device_config.device}, " f" device_config={self.device_config.device}, "
f"decoding_config={self.decoding_config!r}, " f"decoding_config={self.decoding_config!r}, "
f"observability_config={self.observability_config!r}, " f"observability_config={self.observability_config!r}, "
......
...@@ -98,7 +98,6 @@ class EngineArgs: ...@@ -98,7 +98,6 @@ class EngineArgs:
config_format: ConfigFormat = ConfigFormat.AUTO config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto' dtype: str = 'auto'
kv_cache_dtype: str = 'auto' kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
seed: int = 0 seed: int = 0
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
worker_use_ray: bool = False worker_use_ray: bool = False
...@@ -199,6 +198,8 @@ class EngineArgs: ...@@ -199,6 +198,8 @@ class EngineArgs:
generation_config: Optional[str] = None generation_config: Optional[str] = None
enable_sleep_mode: bool = False enable_sleep_mode: bool = False
calculate_kv_scales: Optional[bool] = None
def __post_init__(self): def __post_init__(self):
if not self.tokenizer: if not self.tokenizer:
self.tokenizer = self.model self.tokenizer = self.model
...@@ -350,17 +351,6 @@ class EngineArgs: ...@@ -350,17 +351,6 @@ class EngineArgs:
help='Data type for kv cache storage. If "auto", will use model ' help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
default=None,
help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when '
'KV cache dtype is FP8. Otherwise, KV cache scaling factors '
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version '
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
parser.add_argument('--max-model-len', parser.add_argument('--max-model-len',
type=int, type=int,
default=EngineArgs.max_model_len, default=EngineArgs.max_model_len,
...@@ -962,6 +952,15 @@ class EngineArgs: ...@@ -962,6 +952,15 @@ class EngineArgs:
help="Enable sleep mode for the engine. " help="Enable sleep mode for the engine. "
"(only cuda platform is supported)") "(only cuda platform is supported)")
parser.add_argument(
'--calculate-kv-scales',
action='store_true',
help='This enables dynamic calculation of '
'k_scale and v_scale when kv-cache-dtype is fp8. '
'If calculate-kv-scales is false, the scales will '
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
return parser return parser
@classmethod @classmethod
...@@ -991,7 +990,6 @@ class EngineArgs: ...@@ -991,7 +990,6 @@ class EngineArgs:
tokenizer_revision=self.tokenizer_revision, tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
quantization=self.quantization, quantization=self.quantization,
quantization_param_path=self.quantization_param_path,
enforce_eager=self.enforce_eager, enforce_eager=self.enforce_eager,
max_seq_len_to_capture=self.max_seq_len_to_capture, max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs, max_logprobs=self.max_logprobs,
...@@ -1068,6 +1066,7 @@ class EngineArgs: ...@@ -1068,6 +1066,7 @@ class EngineArgs:
sliding_window=model_config.get_sliding_window(), sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching, enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb, cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
) )
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size, pipeline_parallel_size=self.pipeline_parallel_size,
......
...@@ -73,6 +73,8 @@ if TYPE_CHECKING: ...@@ -73,6 +73,8 @@ if TYPE_CHECKING:
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False VLLM_DISABLE_COMPILE_CACHE: bool = False
K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100
VLLM_SERVER_DEV_MODE: bool = False VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
...@@ -474,6 +476,13 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -474,6 +476,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_V1": "VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT":
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
# Divisor for dynamic value scale factor calculation for FP8 KV Cache
"V_SCALE_CONSTANT":
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
# If set, enable multiprocessing in LLM for the V1 code path. # If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING": "VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
......
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -40,11 +41,16 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -40,11 +41,16 @@ class BaseKVCacheMethod(QuantizeMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint. # regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto": # No need to process kv scales after loading if we are going to
# calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
if layer.k_scale > 0.0 and layer.v_scale > 0.0: if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present # We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist() k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist()
if current_platform.is_rocm():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0: elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
# If no scales were loaded (both scales are invalid negative # If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0 # values), use the default value of 1.0
...@@ -58,6 +64,9 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -58,6 +64,9 @@ class BaseKVCacheMethod(QuantizeMethodBase):
scale_to_duplicate = max(layer.k_scale, layer.v_scale) scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist() k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist()
if current_platform.is_rocm():
k_scale *= 2
v_scale *= 2
if not isinstance(k_scale, float) or not isinstance( if not isinstance(k_scale, float) or not isinstance(
v_scale, float): v_scale, float):
...@@ -65,9 +74,11 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -65,9 +74,11 @@ class BaseKVCacheMethod(QuantizeMethodBase):
"for fp8 KV cache") "for fp8 KV cache")
# These are used in the final Attention.forward() # These are used in the final Attention.forward()
layer._k_scale = k_scale layer._k_scale.copy_(k_scale)
layer._v_scale = v_scale layer._v_scale.copy_(v_scale)
if (layer._k_scale == 1.0 and layer._v_scale == 1.0 layer._k_scale_float = k_scale
layer._v_scale_float = v_scale
if (k_scale == 1.0 and v_scale == 1.0
and "e5m2" not in layer.kv_cache_dtype): and "e5m2" not in layer.kv_cache_dtype):
logger.warning_once( logger.warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This " "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
......
...@@ -6,8 +6,7 @@ import json ...@@ -6,8 +6,7 @@ import json
import os import os
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
Tuple, Union)
import filelock import filelock
import gguf import gguf
...@@ -23,7 +22,6 @@ from vllm.distributed import get_tensor_model_parallel_rank ...@@ -23,7 +22,6 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig, from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config) get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
...@@ -496,47 +494,6 @@ def gguf_quant_weights_iterator( ...@@ -496,47 +494,6 @@ def gguf_quant_weights_iterator(
yield name, param yield name, param
def kv_cache_scales_loader(
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
KV cache scaling factors. The serialization should represent a dictionary
whose keys are the TP ranks and values are another dictionary mapping layers
to their KV cache scaling factors.
Keep this function in sync with the output of
examples/other/fp8/extract_scales.py
"""
try:
with open(filename) as f:
context = {
"model_type": model_type,
"num_hidden_layers": num_hidden_layers,
"tp_rank": tp_rank,
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct,
context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
except FileNotFoundError:
logger.error("File or directory '%s' not found.", filename)
except json.JSONDecodeError:
logger.error("Error decoding JSON in file '%s'.", filename)
except Exception:
logger.exception("An error occurred while reading '%s'.", filename)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 for all "
"layers in TP rank %d as an error occurred during loading.", tp_rank)
return []
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor """convert PySafeSlice object from safetensors to torch.Tensor
......
...@@ -30,8 +30,7 @@ from torch import nn ...@@ -30,8 +30,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig
...@@ -576,32 +574,3 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -576,32 +574,3 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.transformer.h[layer_idx], nn.Identity):
layer_self_attn = self.transformer.h[layer_idx].attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
...@@ -29,8 +29,7 @@ from transformers import GraniteConfig ...@@ -29,8 +29,7 @@ from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -518,29 +516,3 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -518,29 +516,3 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
...@@ -29,8 +29,7 @@ from transformers import LlamaConfig ...@@ -29,8 +29,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -43,9 +42,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -43,9 +42,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -440,32 +438,6 @@ class LlamaModel(nn.Module): ...@@ -440,32 +438,6 @@ class LlamaModel(nn.Module):
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
...@@ -593,9 +565,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -593,9 +565,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.maybe_remap_mistral(name, loaded_weight) self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights) for name, loaded_weight in weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)
# This function is used to remap the mistral format as # This function is used to remap the mistral format as
# used by Mistral and Llama <=2 # used by Mistral and Llama <=2
def maybe_remap_mistral( def maybe_remap_mistral(
......
...@@ -831,6 +831,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -831,6 +831,7 @@ class MllamaTextCrossAttention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
# Skip writing kv-cache for the initial profiling run. # Skip writing kv-cache for the initial profiling run.
if len(kv_cache.shape) > 1: if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32)
if self.attn.backend in (_Backend.FLASH_ATTN, if self.attn.backend in (_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1): _Backend.FLASH_ATTN_VLLM_V1):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
...@@ -843,8 +844,8 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -843,8 +844,8 @@ class MllamaTextCrossAttention(nn.Module):
attn_metadata. attn_metadata.
cross_slot_mapping, # type: ignore[union-attr] cross_slot_mapping, # type: ignore[union-attr]
"auto", "auto",
1.0, i,
1.0, i,
) )
elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA): elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
...@@ -853,7 +854,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -853,7 +854,7 @@ class MllamaTextCrossAttention(nn.Module):
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache( PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache, cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) attn_metadata.cross_slot_mapping, "auto", i, i)
else: else:
raise ValueError( raise ValueError(
f"Unsupported Attention backend {self.attn.backend} " f"Unsupported Attention backend {self.attn.backend} "
......
...@@ -30,8 +30,7 @@ from transformers import PretrainedConfig ...@@ -30,8 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -44,9 +43,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
...@@ -535,32 +533,3 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -535,32 +533,3 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path,
tp_rank,
tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type,
):
if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if current_platform.is_rocm():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
...@@ -166,10 +166,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -166,10 +166,6 @@ class FlashAttentionImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if attn_metadata is None: if attn_metadata is None:
......
...@@ -903,7 +903,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -903,7 +903,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps= multi_modal_placeholder_index_maps=
None # FIXME(kzawora): mutli-modality will not work here None, # FIXME(kzawora): mutli-modality will not work here
enable_kv_scales_calculation=False,
) )
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
...@@ -1057,7 +1058,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1057,7 +1058,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None) multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)
return PrepareDecodeMetadata(input_tokens=input_tokens, return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
......
...@@ -3,7 +3,6 @@ import gc ...@@ -3,7 +3,6 @@ import gc
import inspect import inspect
import itertools import itertools
import time import time
import warnings
import weakref import weakref
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -41,7 +40,6 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes ...@@ -41,7 +40,6 @@ from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap, MultiModalKwargs, MultiModalPlaceholderMap,
MultiModalRegistry) MultiModalRegistry)
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import ( from vllm.prompt_adapter.worker_manager import (
...@@ -1151,34 +1149,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1151,34 +1149,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.prompt_adapter_manager.create_prompt_adapter_manager( self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model)) self.model))
if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
or current_platform.is_cuda()):
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s",
self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
"model %s does not support loading scaling factors.",
self.model.__class__)
else:
logger.warning(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if self.vllm_config.compilation_config.level ==\ if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
backend = self.vllm_config.compilation_config.init_backend( backend = self.vllm_config.compilation_config.init_backend(
...@@ -1366,6 +1336,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1366,6 +1336,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype=self.model_config.dtype, dtype=self.model_config.dtype,
device=self.device) device=self.device)
# Disable KV Scale Calculation for dummy data during profile run
if model_input.attn_metadata is not None:
model_input.attn_metadata.enable_kv_scales_calculation = False
self.execute_model(model_input, kv_caches, intermediate_tensors) self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize() torch.cuda.synchronize()
return return
...@@ -1510,7 +1484,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1510,7 +1484,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
batch_size, batch_size,
is_encoder_decoder_model=self.model_config. is_encoder_decoder_model=self.model_config.
is_encoder_decoder)) is_encoder_decoder))
# Disable KV Scale Calculation for graph capture
attn_metadata.enable_kv_scales_calculation = False
if self.lora_config: if self.lora_config:
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
**dict(index_mapping=[0] * batch_size, **dict(index_mapping=[0] * batch_size,
......
...@@ -282,6 +282,7 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -282,6 +282,7 @@ class OpenVINOModelRunner(ModelRunnerBase):
block_indices_begins=block_indices_begins_tensor, block_indices_begins=block_indices_begins_tensor,
max_context_len=max_context_len_tensor, max_context_len=max_context_len_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
) )
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
......
...@@ -190,6 +190,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -190,6 +190,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=None, block_tables=None,
context_lens=None, context_lens=None,
effective_query_lens=None, effective_query_lens=None,
...@@ -208,6 +209,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -208,6 +209,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
effective_query_lens=effective_query_lens, effective_query_lens=effective_query_lens,
...@@ -239,6 +241,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -239,6 +241,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=batch_size * seq_len, num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
) )
...@@ -425,6 +428,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -425,6 +428,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
effective_query_lens=prompt_lens, effective_query_lens=prompt_lens,
...@@ -496,6 +500,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -496,6 +500,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_decode_tokens=batch_size, num_decode_tokens=batch_size,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
block_tables=block_tables, block_tables=block_tables,
context_lens=context_lens, context_lens=context_lens,
) )
......
...@@ -261,6 +261,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -261,6 +261,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
seq_lens=seq_lens, seq_lens=seq_lens,
seqlen_q=seqlen_q, seqlen_q=seqlen_q,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
...@@ -345,6 +346,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ...@@ -345,6 +346,7 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=seq_lens, seq_lens=seq_lens,
seqlen_q=torch.tensor([]), seqlen_q=torch.tensor([]),
max_seqlen=0, max_seqlen=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment