Unverified commit c84e9242 authored by Woosuk Kwon, committed by GitHub

[Minor] Fix a dtype bug (#79)

parent c9d5b6d4
```diff
@@ -37,7 +37,11 @@ _MEMORY_ANALYZERS = {

 def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
-    config_dtype: torch.dtype = getattr(config, 'torch_dtype', torch.float32)
+    # NOTE: getattr(config, 'torch_dtype', torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, 'torch_dtype', None)
+    if config_dtype is None:
+        config_dtype = torch.float32
     if dtype == 'default':
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
```
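For context, here is a minimal sketch of the pitfall this commit fixes: `getattr` only falls back to its default when the attribute is *missing*, not when it is present but set to `None`, and Hugging Face configs can carry `torch_dtype=None`. `FakeConfig` below is a hypothetical stand-in for `transformers.PretrainedConfig`, not the real class.

```python
import torch


class FakeConfig:
    # Hypothetical stand-in for transformers.PretrainedConfig;
    # HF configs may legitimately have torch_dtype set to None.
    torch_dtype = None


config = FakeConfig()

# Buggy pattern: the attribute exists, so getattr returns None and the
# torch.float32 default is never applied.
buggy = getattr(config, 'torch_dtype', torch.float32)
assert buggy is None

# Fixed pattern from the commit: fall back to torch.float32 explicitly
# whenever the value is missing *or* None.
fixed = getattr(config, 'torch_dtype', None)
if fixed is None:
    fixed = torch.float32
assert fixed == torch.float32
```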