Unverified Commit a1344dbf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix dtype getter (#17668)

* Fix dtype getters

* Proper fix for dtype getter

* Style and commant

* Always use last for consistency

* Quality
parent 73083581
......@@ -139,7 +139,7 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
try:
return next(parameter.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
......@@ -152,31 +152,33 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the first dtype it found.
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
try:
last_dtype = None
for t in parameter.parameters():
last_dtype = t.dtype
if t.is_floating_point():
return t.dtype
# if no floating dtype was found return whatever the first dtype is
else:
return next(parameter.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
else:
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
# fallback to any dtype the model has even if not floating
else:
first_tuple = next(gen)
return first_tuple[1].dtype
# fallback to the last dtype
return last_tuple[1].dtype
def get_state_dict_float_dtype(state_dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment