Unverified Commit 2858830c authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Prioritize dtype in root config before checking text config (#17629)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d6484ef3
...@@ -2954,10 +2954,12 @@ def _get_and_verify_dtype( ...@@ -2954,10 +2954,12 @@ def _get_and_verify_dtype(
) -> torch.dtype: ) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None. # because config.torch_dtype can be None.
config_dtype = getattr(config.get_text_config(), "torch_dtype", None) config_dtype = getattr(config, "torch_dtype", None)
# Fallback for multi-modal models if the root config # Fallbacks for multi-modal models if the root config
# does not define torch_dtype # does not define torch_dtype
if config_dtype is None:
config_dtype = getattr(config.get_text_config(), "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"): if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None) config_dtype = getattr(config.vision_config, "torch_dtype", None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment