Unverified Commit 8efd9ce7 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Chore] clean residue from copy-pasting in the UNet single file loader (#7295)

clean residue from copy-pasting
parent 299c16d0
...@@ -905,14 +905,14 @@ class UNet2DConditionLoadersMixin: ...@@ -905,14 +905,14 @@ class UNet2DConditionLoadersMixin:
class FromOriginalUNetMixin: class FromOriginalUNetMixin:
""" """
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
""" """
@classmethod @classmethod
@validate_hf_hub_args @validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs): def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r""" r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters: Parameters:
...@@ -951,6 +951,10 @@ class FromOriginalUNetMixin: ...@@ -951,6 +951,10 @@ class FromOriginalUNetMixin:
Can be used to overwrite load and saveable variables of the model. Can be used to overwrite load and saveable variables of the model.
""" """
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
...@@ -961,10 +965,6 @@ class FromOriginalUNetMixin: ...@@ -961,10 +965,6 @@ class FromOriginalUNetMixin:
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
checkpoint = load_single_file_model_checkpoint( checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path, pretrained_model_link_or_path,
resume_download=resume_download, resume_download=resume_download,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment