Unverified Commit 23bc56a0 authored by xieofxie's avatar xieofxie Committed by GitHub
Browse files

add provider_options in from_pretrained (#10719)


Co-authored-by: default avatarhualxie <hualxie@microsoft.com>
parent 5b1dcd15
...@@ -630,6 +630,7 @@ def load_sub_model( ...@@ -630,6 +630,7 @@ def load_sub_model(
cached_folder: Union[str, os.PathLike], cached_folder: Union[str, os.PathLike],
use_safetensors: bool, use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]], dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
): ):
"""Helper method to load the module `name` from `library_name` and `class_name`""" """Helper method to load the module `name` from `library_name` and `class_name`"""
...@@ -676,6 +677,7 @@ def load_sub_model( ...@@ -676,6 +677,7 @@ def load_sub_model(
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel): if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
loading_kwargs["provider"] = provider loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options loading_kwargs["sess_options"] = sess_options
loading_kwargs["provider_options"] = provider_options
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
......
...@@ -677,6 +677,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -677,6 +677,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
custom_revision = kwargs.pop("custom_revision", None) custom_revision = kwargs.pop("custom_revision", None)
provider = kwargs.pop("provider", None) provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None) sess_options = kwargs.pop("sess_options", None)
provider_options = kwargs.pop("provider_options", None)
device_map = kwargs.pop("device_map", None) device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None) max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None) offload_folder = kwargs.pop("offload_folder", None)
...@@ -971,6 +972,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -971,6 +972,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
cached_folder=cached_folder, cached_folder=cached_folder,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
dduf_entries=dduf_entries, dduf_entries=dduf_entries,
provider_options=provider_options,
) )
logger.info( logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment