Unverified Commit b8809091 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix dtype getters (#17656)

parent fd1e6703
......@@ -152,7 +152,7 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
Returns the first found floating dtype in parameters if there is one, otherwise returns the first dtype it found.
"""
try:
for t in parameter.parameters():
......@@ -160,7 +160,7 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
return t.dtype
# if no floating dtype was found return whatever the first dtype is
else:
return t.dtype
return next(parameter.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
......@@ -175,7 +175,8 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
return tuple[1].dtype
# fallback to any dtype the model has even if not floating
else:
return tuple[1].dtype
first_tuple = next(gen)
return first_tuple[1].dtype
def get_state_dict_float_dtype(state_dict):
......@@ -191,7 +192,7 @@ def get_state_dict_float_dtype(state_dict):
def get_state_dict_dtype(state_dict):
"""
Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the last dtype.
Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
"""
for t in state_dict.values():
if t.is_floating_point():
......@@ -199,7 +200,7 @@ def get_state_dict_dtype(state_dict):
# if no floating dtype was found return whatever the first dtype is
else:
return t.dtype
return next(state_dict.values()).dtype
def convert_file_size_to_int(size: Union[int, str]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment