Unverified commit f690372b, authored by Cyrus Leung, committed by GitHub
Browse files

[Core] Update dtype detection and defaults (#14858)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 8b3e94a3
...@@ -50,7 +50,7 @@ def _get_test_sampling_params( ...@@ -50,7 +50,7 @@ def _get_test_sampling_params(
"""Generate random sampling params for a batch.""" """Generate random sampling params for a batch."""
def get_mostly_n_gt1() -> int: def get_mostly_n_gt1() -> int:
"""Mostly n \in [2,20], ~1/3 n=1""" r"""Mostly n \in [2,20], ~1/3 n=1"""
x = random.randint(0, 28) x = random.randint(0, 28)
if x < 10: if x < 10:
return 1 return 1
......
...@@ -347,7 +347,7 @@ class ModelConfig: ...@@ -347,7 +347,7 @@ class ModelConfig:
self.encoder_config = self._get_encoder_config() self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config( self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision) self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs self.mm_processor_kwargs = mm_processor_kwargs
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
...@@ -2526,6 +2526,14 @@ def _get_and_verify_dtype( ...@@ -2526,6 +2526,14 @@ def _get_and_verify_dtype(
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None. # because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None) config_dtype = getattr(config, "torch_dtype", None)
# Fallbacks for multi-modal models if the root config
# does not define torch_dtype
if config_dtype is None and hasattr(config, "text_config"):
config_dtype = getattr(config.text_config, "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None)
if config_dtype is None: if config_dtype is None:
config_dtype = torch.float32 config_dtype = torch.float32
...@@ -2533,16 +2541,8 @@ def _get_and_verify_dtype( ...@@ -2533,16 +2541,8 @@ def _get_and_verify_dtype(
dtype = dtype.lower() dtype = dtype.lower()
if dtype == "auto": if dtype == "auto":
if config_dtype == torch.float32: if config_dtype == torch.float32:
if config.model_type in ("gemma2", "gemma3", "gemma3_text"): # Following common practice, we use float16 for float32 models
logger.info( torch_dtype = torch.float16
"For Gemma 2 and 3, we downcast float32 to bfloat16 "
"instead of float16 by default. Please specify `dtype` "
"if you want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else: else:
torch_dtype = config_dtype torch_dtype = config_dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment