Unverified Commit a1344dbf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix dtype getter (#17668)

* Fix dtype getters

* Proper fix for dtype getter

* Style and comment

* Always use last for consistency

* Quality
parent 73083581
...@@ -139,7 +139,7 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu ...@@ -139,7 +139,7 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
try: try:
return next(parameter.parameters()).dtype return next(parameter.parameters()).dtype
except StopIteration: except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5 # For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
...@@ -152,31 +152,33 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu ...@@ -152,31 +152,33 @@ def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "Modu
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
    """
    Return the first floating-point dtype found in ``parameter``'s parameters, falling back to
    the last dtype seen if no floating-point parameter exists.

    Args:
        parameter: A module (or mixin wrapping one) exposing ``parameters()``.

    Returns:
        ``torch.dtype``: the dtype of the first floating-point parameter, otherwise the dtype of
        the last parameter iterated; if the module has no parameters at all, the same search is
        repeated over tensor attributes (for ``nn.DataParallel`` compatibility in PyTorch > 1.5).
    """
    last_dtype = None
    for t in parameter.parameters():
        # Track the most recent dtype so we have a fallback if nothing is floating point.
        last_dtype = t.dtype
        if t.is_floating_point():
            return t.dtype

    if last_dtype is not None:
        # No floating dtype was found: return the last dtype encountered.
        return last_dtype
    else:
        # The module exposed no parameters at all. For nn.DataParallel compatibility in
        # PyTorch > 1.5, fall back to scanning tensor attributes stored directly on modules.
        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            # Collect (name, tensor) pairs from the module's instance attributes.
            return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        last_named_tensor = None
        for named_tensor in gen:  # renamed from `tuple`, which shadowed the builtin
            last_named_tensor = named_tensor
            if named_tensor[1].is_floating_point():
                return named_tensor[1].dtype

        # Fallback to the last (non-floating) dtype the model has.
        return last_named_tensor[1].dtype
def get_state_dict_float_dtype(state_dict): def get_state_dict_float_dtype(state_dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment