Unverified Commit 7c2f0afb authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

update `get_parameter_dtype` (#10342)

add:
q
parent f615f00f
...@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device: ...@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
try: """
return next(parameter.parameters()).dtype Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
except StopIteration: """
try: last_dtype = None
return next(parameter.buffers()).dtype for param in parameter.parameters():
except StopIteration: last_dtype = param.dtype
# For torch.nn.DataParallel compatibility in PyTorch 1.5 if param.is_floating_point():
return param.dtype
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] for buffer in parameter.buffers():
return tuples last_dtype = buffer.dtype
if buffer.is_floating_point():
gen = parameter._named_members(get_members_fn=find_tensor_attributes) return buffer.dtype
first_tuple = next(gen)
return first_tuple[1].dtype if last_dtype is not None:
# if no floating dtype was found return whatever the first dtype is
return last_dtype
# For nn.DataParallel compatibility in PyTorch > 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
last_tuple = None
for tuple in gen:
last_tuple = tuple
if tuple[1].is_floating_point():
return tuple[1].dtype
if last_tuple is not None:
# fallback to the last dtype
return last_tuple[1].dtype
class ModelMixin(torch.nn.Module, PushToHubMixin): class ModelMixin(torch.nn.Module, PushToHubMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment