Unverified Commit 75343de9 authored by Stas Bekman, committed by GitHub

[modeling_utils] torch_dtype/auto floating dtype fixes (#17614)



* [modeling_utils] torch_dtype/auto fixes

* add test

* apply suggestions

* add missing fallback

* Renaming things

* Use for else
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
parent c38f4e1f
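
For context, a minimal sketch of the user-visible behavior this commit fixes (the checkpoint is the one exercised by the new test added below; integer buffers such as `position_ids` can be serialized before any floating-point weight):

```python
import torch
from transformers import AutoModel

# Before this fix, torch_dtype="auto" probed the *first* state_dict entry,
# which for this tiny BERT checkpoint is an integer tensor, not a floating
# point weight. After the fix, the first *floating* dtype found is used.
model = AutoModel.from_pretrained(
    "hf-internal-testing/tiny-bert-for-token-classification", torch_dtype="auto"
)
assert model.dtype == torch.float32
```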
src/transformers/modeling_utils.py
@@ -132,7 +132,10 @@ def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUti
         return first_tuple[1].device
 
 
-def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+    """
+    Returns the first parameter dtype (can be non-floating) or asserts if none were found.
+    """
     try:
         return next(parameter.parameters()).dtype
     except StopIteration:
@@ -147,6 +150,58 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
         return first_tuple[1].dtype
 
 
+def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+    """
+    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+    """
+    try:
+        for t in parameter.parameters():
+            if t.is_floating_point():
+                return t.dtype
+        # if no floating dtype was found, return whatever the last dtype is
+        else:
+            return t.dtype
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        for tuple in gen:
+            if tuple[1].is_floating_point():
+                return tuple[1].dtype
+        # fallback to any dtype the model has even if not floating
+        else:
+            return tuple[1].dtype
+
+
+def get_state_dict_float_dtype(state_dict):
+    """
+    Returns the first found floating dtype in `state_dict` or asserts if none were found.
+    """
+    for t in state_dict.values():
+        if t.is_floating_point():
+            return t.dtype
+
+    raise ValueError("couldn't find any floating point dtypes in state_dict")
+
+
+def get_state_dict_dtype(state_dict):
+    """
+    Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the last dtype.
+    """
+    for t in state_dict.values():
+        if t.is_floating_point():
+            return t.dtype
+    # if no floating dtype was found, return whatever the last dtype is
+    else:
+        return t.dtype
+
+
 def convert_file_size_to_int(size: Union[int, str]):
     """
     Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).
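
The rule shared by the new `get_parameter_dtype` and `get_state_dict_dtype` helpers above is: return the first floating dtype encountered, otherwise fall back to the last dtype seen. A standalone sketch (not the library code itself) of that rule:

```python
import torch

def pick_dtype(tensors):
    """First floating dtype if any, else the last dtype seen (None if empty)."""
    last_dtype = None
    for t in tensors:
        last_dtype = t.dtype
        if t.is_floating_point():
            return t.dtype  # first floating dtype wins
    return last_dtype  # no floating tensor found: fall back to the last dtype

print(pick_dtype([torch.zeros(2, dtype=torch.long), torch.zeros(2, dtype=torch.float16)]))  # torch.float16
print(pick_dtype([torch.zeros(2, dtype=torch.long)]))  # torch.int64
```

The patch itself expresses the fallback with Python's `for ... else`, whose `else` branch runs only when the loop finishes without `return`ing; the explicit `last_dtype` variable in this sketch is just a more defensive spelling that also survives an empty iterable.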
@@ -2076,7 +2131,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # set dtype to instantiate the model under:
         # 1. If torch_dtype is not None, we use that dtype
         # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
-        #    weights entry - we assume all weights are of the same dtype
+        #    weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
         # we also may have config.torch_dtype available, but we won't rely on it till v5
         dtype_orig = None
         if torch_dtype is not None:
@@ -2085,10 +2140,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     if is_sharded and "dtype" in sharded_metadata:
                         torch_dtype = sharded_metadata["dtype"]
                     elif not is_sharded:
-                        torch_dtype = next(iter(state_dict.values())).dtype
+                        torch_dtype = get_state_dict_dtype(state_dict)
                     else:
                         one_state_dict = load_state_dict(resolved_archive_file)
-                        torch_dtype = next(iter(one_state_dict.values())).dtype
+                        torch_dtype = get_state_dict_dtype(one_state_dict)
                         del one_state_dict  # free CPU memory
                 else:
                     raise ValueError(
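
To make the `from_pretrained` change above concrete, here is a hedged before/after sketch of the dtype probe on a raw state dict (the dict contents are illustrative; `get_state_dict_dtype` is the helper added in this commit):

```python
import torch

# A checkpoint where an integer buffer happens to be serialized first.
state_dict = {
    "embeddings.position_ids": torch.arange(512),              # torch.int64
    "embeddings.word_embeddings.weight": torch.zeros(10, 4),   # torch.float32
}

# Old probe: whatever entry comes first, floating or not.
old_guess = next(iter(state_dict.values())).dtype  # torch.int64 -- wrong for "auto"

# New probe: first floating dtype in the state_dict, else the last dtype seen.
def get_state_dict_dtype(state_dict):
    for t in state_dict.values():
        if t.is_floating_point():
            return t.dtype
    return t.dtype  # fallback: last dtype seen

new_guess = get_state_dict_dtype(state_dict)  # torch.float32
print(old_guess, new_guess)
```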
tests/test_modeling_common.py
@@ -134,6 +134,7 @@ def _config_zero_init(config):
 TINY_T5 = "patrickvonplaten/t5-tiny-random"
+TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
 
 
 @require_torch
@@ -2557,6 +2558,10 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+        # test model whose first param is not of a floating type, but int
+        model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
+        self.assertEqual(model.dtype, torch.float32)
+
     def test_no_super_init_config_and_model(self):
         config = NoSuperInitConfig(attribute=32)
         model = NoSuperInitModel(config)