Unverified Commit 80572c83 authored by brayden-hai's avatar brayden-hai Committed by GitHub
Browse files

[ModelOpt] Respect `kv_cache_quant_algo` in ModelOpt checkpoints (#10336)


Co-authored-by: default avatarBaizhou Zhang <sobereddiezhang@gmail.com>
parent 4bb08f6e
...@@ -135,6 +135,7 @@ from sglang.srt.utils import ( ...@@ -135,6 +135,7 @@ from sglang.srt.utils import (
is_no_spec_infer_or_topk_one, is_no_spec_infer_or_topk_one,
is_npu, is_npu,
is_sm100_supported, is_sm100_supported,
log_info_on_rank0,
monkey_patch_p2p_access_check, monkey_patch_p2p_access_check,
monkey_patch_vllm_gguf_config, monkey_patch_vllm_gguf_config,
parse_connector_type, parse_connector_type,
...@@ -1352,6 +1353,17 @@ class ModelRunner: ...@@ -1352,6 +1353,17 @@ class ModelRunner:
): ):
# Determine the kv cache dtype # Determine the kv cache dtype
if self.server_args.kv_cache_dtype == "auto": if self.server_args.kv_cache_dtype == "auto":
quant_config = getattr(self.model, "quant_config", None)
kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
if (
isinstance(kv_cache_quant_algo, str)
and kv_cache_quant_algo.upper() == "FP8"
):
if _is_hip:
self.kv_cache_dtype = torch.float8_e4m3fnuz
else:
self.kv_cache_dtype = torch.float8_e4m3fn
else:
self.kv_cache_dtype = self.dtype self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2": elif self.server_args.kv_cache_dtype == "fp8_e5m2":
if _is_hip: # Using natively supported format if _is_hip: # Using natively supported format
...@@ -1368,6 +1380,8 @@ class ModelRunner: ...@@ -1368,6 +1380,8 @@ class ModelRunner:
f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
) )
log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if SGLANG_CI_SMALL_KV_SIZE: if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE) self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment