Unverified Commit 12878229 authored by Eliseu Silva's avatar Eliseu Silva Committed by GitHub
Browse files

Fix for use_safetensors parameters, allow use of parameter on loading submodels (#9576) (#9587)

* Fix for use_safetensors parameters, allow use of parameter on loading submodels (#9576)
parent a80f6892
...@@ -601,6 +601,7 @@ def load_sub_model( ...@@ -601,6 +601,7 @@ def load_sub_model(
variant: str, variant: str,
low_cpu_mem_usage: bool, low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike], cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
): ):
"""Helper method to load the module `name` from `library_name` and `class_name`""" """Helper method to load the module `name` from `library_name` and `class_name`"""
...@@ -670,6 +671,7 @@ def load_sub_model( ...@@ -670,6 +671,7 @@ def load_sub_model(
loading_kwargs["offload_folder"] = offload_folder loading_kwargs["offload_folder"] = offload_folder
loading_kwargs["offload_state_dict"] = offload_state_dict loading_kwargs["offload_state_dict"] = offload_state_dict
loading_kwargs["variant"] = model_variants.pop(name, None) loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors
if from_flax: if from_flax:
loading_kwargs["from_flax"] = True loading_kwargs["from_flax"] = True
......
...@@ -905,6 +905,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -905,6 +905,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant=variant, variant=variant,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder, cached_folder=cached_folder,
use_safetensors=use_safetensors,
) )
logger.info( logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment