"tests/vscode:/vscode.git/clone" did not exist on "d2b52805f24ba63e3bc7d873f74caa17ac9ff36f"
Unverified Commit 116f4be4 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[1/N][Cleanup] Standardize on use of `is_quantized_kv_cache` (#38659)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 7b01d97a
......@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -231,7 +232,7 @@ class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]):
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet")
# Concatenate q if it's a tuple (ql_nope, q_pe)
......
......@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -291,7 +292,7 @@ if current_platform.is_rocm():
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
QUANT = False
if kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(kv_cache_dtype):
QUANT = True
grid = (
num_tokens,
......@@ -494,7 +495,7 @@ class AiterFlashAttentionMetadataBuilder(
if (
rocm_aiter_ops.is_shuffle_kv_cache_enabled()
and self.scale.numel() == 1
and self.vllm_config.cache_config.cache_dtype.startswith("fp8")
and is_quantized_kv_cache(self.vllm_config.cache_config.cache_dtype)
):
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
first_layer_name = [k for k in layers][0]
......@@ -887,7 +888,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv=swa_cu_seqlens,
token_to_batch=swa_token_to_batch,
seq_starts=swa_seq_starts,
dequant=self.kv_cache_dtype.startswith("fp8"),
dequant=is_quantized_kv_cache(self.kv_cache_dtype),
kv_cache_layout="NHD",
total_tokens=swa_total_tokens,
)
......@@ -982,7 +983,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
token_to_batch=token_to_batch[chunk_idx],
seq_starts=chunk_starts[chunk_idx],
dequant=self.kv_cache_dtype.startswith("fp8"),
dequant=is_quantized_kv_cache(self.kv_cache_dtype),
kv_cache_layout="SHUFFLE"
if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
else "NHD",
......@@ -1081,7 +1082,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
......@@ -1370,7 +1371,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
# Reshape the input keys and values and store them in the cache.
......@@ -1436,7 +1437,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache:
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
......
......@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
......@@ -200,7 +201,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
softmax_scale = self.scale
fp8_post_attn_v_rescale = False
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul).
......@@ -299,7 +300,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
......
......@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -315,7 +316,7 @@ class RocmAttentionImpl(AttentionImpl):
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"quantization is not supported for encoder attention"
)
......@@ -406,7 +407,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache, self.num_kv_heads, self.head_size
)
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
......@@ -513,7 +514,7 @@ class RocmAttentionImpl(AttentionImpl):
)
flash_layout = False
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
......
......@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -472,7 +473,7 @@ class TritonAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
......@@ -546,7 +547,7 @@ class TritonAttentionImpl(AttentionImpl):
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"quantization is not supported for encoder attention"
)
......@@ -588,7 +589,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(1)
# Reshape the input keys and values and store them in the cache.
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
......@@ -623,7 +624,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(1)
flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
......
......@@ -5,6 +5,7 @@ import torch
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_quantized_kv_cache
@triton.jit
......@@ -145,16 +146,18 @@ def triton_reshape_and_cache_flash(
block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1]
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
)
kv_cache_torch_dtype = (
current_platform.fp8_dtype()
if kv_cache_dtype.startswith("fp8")
if is_quantized_kv_cache(kv_cache_dtype)
else key_cache.dtype
)
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
if key_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(
kv_cache_dtype
):
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype)
......@@ -164,7 +167,7 @@ def triton_reshape_and_cache_flash(
"uint8 is not supported by triton reshape_and_cache_flash"
)
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
......@@ -323,16 +326,16 @@ def triton_reshape_and_cache_flash_diffkv(
block_stride = kv_cache.stride()[0]
page_stride = kv_cache.stride()[1]
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
)
kv_cache_torch_dtype = (
current_platform.fp8_dtype()
if kv_cache_dtype.startswith("fp8")
if is_quantized_kv_cache(kv_cache_dtype)
else kv_cache.dtype
)
if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
if kv_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(kv_cache_dtype):
# to avoid erounous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
kv_cache = kv_cache.view(kv_cache_torch_dtype)
......@@ -341,7 +344,7 @@ def triton_reshape_and_cache_flash_diffkv(
"uint8 is not supported by triton reshape_and_cache_flash_diffkv"
)
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
......
......@@ -109,6 +109,7 @@ from vllm.utils.nvtx_pytorch_hooks import PytHooks
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
from vllm.utils.torch_utils import (
get_dtype_size,
is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype,
)
from vllm.v1.attention.backend import (
......@@ -896,7 +897,7 @@ class GPUModelRunner(
If these are left at 0.0 (default after wake_up), all KV cache values
become effectively zero, causing gibberish output.
"""
if not self.cache_config.cache_dtype.startswith("fp8"):
if not is_quantized_kv_cache(self.cache_config.cache_dtype):
return
kv_caches = getattr(self, "kv_caches", [])
......
......@@ -46,7 +46,7 @@ from vllm.tasks import SupportedTask
from vllm.tracing import instrument
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.utils.torch_utils import is_quantized_kv_cache, set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
......@@ -197,7 +197,7 @@ class Worker(WorkerBase):
# especially the FP8 scaling factor.
if (
(tags is None or "kv_cache" in tags)
and self.cache_config.cache_dtype.startswith("fp8")
and is_quantized_kv_cache(self.cache_config.cache_dtype)
and hasattr(self.model_runner, "init_fp8_kv_scales")
):
self.model_runner.init_fp8_kv_scales()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment