Unverified Commit 6ec0d8db authored by danielafrimi's avatar danielafrimi Committed by GitHub
Browse files

[Fix]Load kv-cache dtype from hf_quant_config.json automatically (#29980)


Signed-off-by: default avatarDaniel Afrimi <dafrimi@nvidia.com>
parent 9693dd0f
...@@ -194,12 +194,33 @@ def get_kv_cache_torch_dtype( ...@@ -194,12 +194,33 @@ def get_kv_cache_torch_dtype(
return torch_dtype return torch_dtype
def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None:
quant_method = quant_cfg.get("quant_method", "")
if quant_method.startswith("modelopt"):
quantization_inner = quant_cfg.get("quantization", quant_cfg)
# Check if quant config is specified and use kv cache quant algo
kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get(
"kv_cache_quant_algo"
)
if isinstance(kv_algo, str):
return STR_DTYPE_TO_TORCH_DTYPE[kv_algo.lower()]
return None
def kv_cache_dtype_str_to_dtype( def kv_cache_dtype_str_to_dtype(
kv_cache_dtype: str, model_config: ModelConfig kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype: ) -> torch.dtype:
# Model config may not be specified for unit tests, default to float16
dtype = model_config.dtype if model_config else torch.half
if kv_cache_dtype == "auto": if kv_cache_dtype == "auto":
# Model config may not be specified for unit tests, default to float16 hf_cfg = getattr(model_config, "hf_config", None)
return model_config.dtype if model_config else torch.half if hf_cfg is not None:
quant_cfg = getattr(hf_cfg, "quantization_config", None)
if quant_cfg is not None:
kv_algo_dtype = get_kv_cache_quant_algo_dtype(quant_cfg)
return kv_algo_dtype if kv_algo_dtype is not None else dtype
return dtype
return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment