Unverified Commit 4bf35ed9 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix] Only add `Attention.kv_scale` if kv cache quantization is enabled (#5936)

parent be0b3af9
...@@ -9,6 +9,7 @@ from vllm.attention.selector import get_attn_backend ...@@ -9,6 +9,7 @@ from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
class Attention(nn.Module): class Attention(nn.Module):
...@@ -56,13 +57,17 @@ class Attention(nn.Module): ...@@ -56,13 +57,17 @@ class Attention(nn.Module):
quant_method = quant_config.get_quant_method( quant_method = quant_config.get_quant_method(
self) if quant_config else None self) if quant_config else None
if quant_method is not None: if quant_method is not None:
assert isinstance(quant_method, Fp8KVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if "fp8" in self.kv_cache_dtype:
if self.kv_cache_dtype == "fp8_e5m2": if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with " raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.") "fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter # When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint. # "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back # The kv_scale will then be converted back to self._kv_scale
# to self._kv_scale in a native float32 value after weight loading. # in a native float32 value after weight loading.
self.quant_method = quant_method self.quant_method = quant_method
self.quant_method.create_weights(self) self.quant_method.create_weights(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment