Unverified Commit b8809091 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix dtype getters (#17656)

parent fd1e6703
...@@ -152,7 +152,7 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu ...@@ -152,7 +152,7 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
""" """
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. Returns the first found floating dtype in parameters if there is one, otherwise returns the first dtype it found.
""" """
try: try:
for t in parameter.parameters(): for t in parameter.parameters():
...@@ -160,7 +160,7 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil ...@@ -160,7 +160,7 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
return t.dtype return t.dtype
# if no floating dtype was found return whatever the first dtype is # if no floating dtype was found return whatever the first dtype is
else: else:
return t.dtype return next(parameter.parameters()).dtype
except StopIteration: except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5 # For nn.DataParallel compatibility in PyTorch 1.5
...@@ -175,7 +175,8 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil ...@@ -175,7 +175,8 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
return tuple[1].dtype return tuple[1].dtype
# fallback to any dtype the model has even if not floating # fallback to any dtype the model has even if not floating
else: else:
return tuple[1].dtype first_tuple = next(gen)
return first_tuple[1].dtype
def get_state_dict_float_dtype(state_dict): def get_state_dict_float_dtype(state_dict):
...@@ -191,7 +192,7 @@ def get_state_dict_float_dtype(state_dict): ...@@ -191,7 +192,7 @@ def get_state_dict_float_dtype(state_dict):
def get_state_dict_dtype(state_dict): def get_state_dict_dtype(state_dict):
""" """
Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the last dtype. Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
""" """
for t in state_dict.values(): for t in state_dict.values():
if t.is_floating_point(): if t.is_floating_point():
...@@ -199,7 +200,7 @@ def get_state_dict_dtype(state_dict): ...@@ -199,7 +200,7 @@ def get_state_dict_dtype(state_dict):
# if no floating dtype was found return whatever the first dtype is # if no floating dtype was found return whatever the first dtype is
else: else:
return t.dtype return next(state_dict.values()).dtype
def convert_file_size_to_int(size: Union[int, str]): def convert_file_size_to_int(size: Union[int, str]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment