Unverified Commit f690372b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Update dtype detection and defaults (#14858)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8b3e94a3
......@@ -50,7 +50,7 @@ def _get_test_sampling_params(
"""Generate random sampling params for a batch."""
def get_mostly_n_gt1() -> int:
"""Mostly n \in [2,20], ~1/3 n=1"""
r"""Mostly n \in [2,20], ~1/3 n=1"""
x = random.randint(0, 28)
if x < 10:
return 1
......
......@@ -347,7 +347,7 @@ class ModelConfig:
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
......@@ -2526,6 +2526,14 @@ def _get_and_verify_dtype(
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
# Fallbacks for multi-modal models if the root config
# does not define torch_dtype
if config_dtype is None and hasattr(config, "text_config"):
config_dtype = getattr(config.text_config, "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
......@@ -2533,16 +2541,8 @@ def _get_and_verify_dtype(
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
logger.info(
"For Gemma 2 and 3, we downcast float32 to bfloat16 "
"instead of float16 by default. Please specify `dtype` "
"if you want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
# Following common practice, we use float16 for float32 models
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment