Unverified Commit c76e8840 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

update get_parameter_dtype (#9526)



* up

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: default avatarAryan <aryan@huggingface.co>

---------
Co-authored-by: default avatarAryan <aryan@huggingface.co>
parent d9c96917
...@@ -93,14 +93,10 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device: ...@@ -93,14 +93,10 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
try: try:
params = tuple(parameter.parameters()) return next(parameter.parameters()).dtype
if len(params) > 0: except StopIteration:
return params[0].dtype try:
return next(parameter.buffers()).dtype
buffers = tuple(parameter.buffers())
if len(buffers) > 0:
return buffers[0].dtype
except StopIteration: except StopIteration:
# For torch.nn.DataParallel compatibility in PyTorch 1.5 # For torch.nn.DataParallel compatibility in PyTorch 1.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment