Unverified Commit 2ff767b5 authored by Adrian Abeyta's avatar Adrian Abeyta Committed by GitHub
Browse files

Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)


Co-authored-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: default avatarHaiShaw <hixiao@gmail.com>
Co-authored-by: default avatarAdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: default avatarMatthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: default avatarroot <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: default avatarmawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: default avatarttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: default avatarguofangze <guofangze@kuaishou.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarjacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 3dcb3e8b
......@@ -23,7 +23,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d,
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
......@@ -120,6 +120,26 @@ class ModelRunner:
self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently scaled KV cache is only enabled on ROCm
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
else:
raise RuntimeError("Using FP8 KV cache and scaling "
"factors provided but model "
f"{self.model.__class__} does not "
"support loading scaling factors.")
else:
logger.warn("Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
logger.warn("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment