Unverified commit 6038b1b0, authored by amitz-nv, committed by GitHub
Browse files

[Frontend][Model] Add 'float16' to possible mamba cache dtype values, override...


[Frontend][Model] Add 'float16' to possible mamba cache dtype values, override mamba SSM cache dtype value for NemotronH (#29978)
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
parent 60a66ea2
......@@ -29,7 +29,7 @@ CacheDType = Literal[
"fp8_inc",
"fp8_ds_mla",
]
# NOTE(review): this first binding is immediately rebound by the next
# statement, so only the second Literal takes effect. It looks like the
# pre-change side of a diff hunk whose +/- markers were lost -- confirm.
MambaDType = Literal["auto", "float32"]
# Accepted string values for the mamba cache dtype; "float16" is allowed in
# addition to "auto" and "float32".
MambaDType = Literal["auto", "float32", "float16"]
# Literal alias listing the accepted prefix-caching hash algorithm names.
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
# Literal alias listing the accepted KV offloading backend names.
KVOffloadingBackend = Literal["native", "lmcache"]
......
......@@ -485,6 +485,26 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    """Config hook that resolves the mamba SSM cache dtype for NemotronH."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Resolve an 'auto' ``mamba_ssm_cache_dtype`` for NemotronH models.

        When the cache config's ``mamba_ssm_cache_dtype`` is 'auto' (or was
        not explicitly set), it is replaced with the value of the HF config's
        ``mamba_ssm_cache_dtype`` attribute, or with 'float16' if the HF
        config does not specify one. Any other value is left untouched.

        Args:
            vllm_config: Config object whose ``cache_config`` is updated in
                place and whose ``model_config.hf_config`` is consulted.
        """
        cache_config = vllm_config.cache_config
        if cache_config.mamba_ssm_cache_dtype != "auto":
            # A non-'auto' value is kept exactly as it was given.
            return

        hf_config = vllm_config.model_config.hf_config
        mamba_ssm_cache_dtype = getattr(hf_config, "mamba_ssm_cache_dtype", None)
        if mamba_ssm_cache_dtype is None:
            # A missing attribute and an explicit None both mean "not
            # specified"; without this, None would be written into the
            # cache config instead of the documented float16 fallback.
            mamba_ssm_cache_dtype = "float16"

        logger.info(
            "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
            mamba_ssm_cache_dtype,
        )
        cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,
......@@ -502,4 +522,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Mamba2ForCausalLM": MambaModelConfig,
"FalconMambaForCausalLM": MambaModelConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"NemotronHForCausalLM": NemotronHForCausalLMConfig,
}
......@@ -28,6 +28,7 @@ else:
STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"half": torch.half,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment