Unverified Commit dfa5062a authored by Netanel Haber's avatar Netanel Haber Committed by GitHub
Browse files

NemotronH default mamba_ssm_cache_dtype=float32; enable auto-hook for...


NemotronH default mamba_ssm_cache_dtype=float32; enable auto-hook for NemotronHNanoVLV2Config (#39032)
Signed-off-by: default avatarNetanel Haber <58652339+netanel-haber@users.noreply.github.com>
parent e8ebbdde
...@@ -7,7 +7,9 @@ from vllm.logger import init_logger ...@@ -7,7 +7,9 @@ from vllm.logger import init_logger
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -346,17 +348,20 @@ class MambaModelConfig(VerifyAndUpdateConfig): ...@@ -346,17 +348,20 @@ class MambaModelConfig(VerifyAndUpdateConfig):
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig): class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod DEFAULT_MAMBA_SSM_CACHE_DTYPE = "float32"
def verify_and_update_config(vllm_config: "VllmConfig") -> None: """Only `float32` is known to have no accuracy issues by default."""
@classmethod
def update_mamba_ssm_cache_dtype(
cls, *, cache_config: "CacheConfig", hf_config: "PretrainedConfig"
) -> None:
"""Update mamba_ssm_cache_dtype for NemotronH models when set to 'auto' """Update mamba_ssm_cache_dtype for NemotronH models when set to 'auto'
(or not explicitly set), to the value specified in the HF config, or to (or not explicitly set), to the value specified in the HF config, or to
float16 if not specified. `float32` if not specified.
""" """
cache_config = vllm_config.cache_config
if cache_config.mamba_ssm_cache_dtype == "auto": if cache_config.mamba_ssm_cache_dtype == "auto":
hf_config = vllm_config.model_config.hf_config
mamba_ssm_cache_dtype = getattr( mamba_ssm_cache_dtype = getattr(
hf_config, "mamba_ssm_cache_dtype", "float16" hf_config, "mamba_ssm_cache_dtype", cls.DEFAULT_MAMBA_SSM_CACHE_DTYPE
) )
logger.info( logger.info(
"Updating mamba_ssm_cache_dtype to '%s' for NemotronH model", "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
...@@ -364,8 +369,22 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig): ...@@ -364,8 +369,22 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
) )
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
cls.update_mamba_ssm_cache_dtype(
cache_config=vllm_config.cache_config,
hf_config=vllm_config.model_config.hf_config,
)
class NemotronHNanoVLV2Config(VerifyAndUpdateConfig): class NemotronHNanoVLV2Config(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
NemotronHForCausalLMConfig.update_mamba_ssm_cache_dtype(
cache_config=vllm_config.cache_config,
hf_config=vllm_config.model_config.hf_config.text_config,
)
@staticmethod @staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None: def verify_and_update_model_config(model_config: "ModelConfig") -> None:
mm_config = model_config.multimodal_config mm_config = model_config.multimodal_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment