Unverified Commit 15001306 authored by Wadim Korablin, committed by GitHub

Support for manual CLIP loading in StableDiffusionPipeline - txt2img. (#3832)



* Support for manual CLIP loading in StableDiffusionPipeline - txt2img.

* Update src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

* Update variables & corresponding docs to match previous style.

* Updated to match style & quality of 'diffusers'

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 219636f7
@@ -1339,6 +1339,17 @@ class FromCkptMixin:
                 "ddim"]`.
             load_safety_checker (`bool`, *optional*, defaults to `True`):
                 Whether to load the safety checker or not.
+            text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+                An instance of
+                [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
+                specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+                variant. If this parameter is `None`, the function loads a new instance of `CLIPTextModel`, if
+                needed.
+            tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+                An instance of
+                [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+                to use. If this parameter is `None`, the function loads a new instance of `CLIPTokenizer`, if
+                needed.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (for example the pipeline components of the
                 specific pipeline class). The overwritten components are directly passed to the pipeline's `__init__`
@@ -1383,6 +1394,8 @@ class FromCkptMixin:
         upcast_attention = kwargs.pop("upcast_attention", None)
         load_safety_checker = kwargs.pop("load_safety_checker", True)
         prediction_type = kwargs.pop("prediction_type", None)
+        text_encoder = kwargs.pop("text_encoder", None)
+        tokenizer = kwargs.pop("tokenizer", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
@@ -1463,6 +1476,8 @@ class FromCkptMixin:
             upcast_attention=upcast_attention,
             load_safety_checker=load_safety_checker,
             prediction_type=prediction_type,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
         )
         if torch_dtype is not None:
......
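
With this change, a preloaded CLIP text encoder and tokenizer can be handed straight to `from_ckpt`, avoiding a fresh download of `openai/clip-vit-large-patch14`. A minimal usage sketch (the checkpoint path and the `my-finetuned-clip` repo name are hypothetical placeholders):

```python
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer

# Load (or reuse) a CLIP text encoder and tokenizer once;
# "my-finetuned-clip" stands in for any clip-vit-large-patch14-compatible repo.
text_encoder = CLIPTextModel.from_pretrained("my-finetuned-clip")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# The new kwargs are popped in from_ckpt and forwarded to the conversion function.
pipe = StableDiffusionPipeline.from_ckpt(
    "./v1-5-pruned-emaonly.ckpt",  # placeholder local checkpoint path
    text_encoder=text_encoder,
    tokenizer=tokenizer,
)
```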
@@ -734,8 +734,12 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
     return hf_model


-def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
+    text_model = (
+        CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+        if text_encoder is None
+        else text_encoder
+    )

     keys = list(checkpoint.keys())
@@ -1025,6 +1029,8 @@ def download_from_original_stable_diffusion_ckpt(
     load_safety_checker: bool = True,
     pipeline_class: DiffusionPipeline = None,
     local_files_only=False,
+    text_encoder=None,
+    tokenizer=None,
 ) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1072,6 +1078,15 @@ def download_from_original_stable_diffusion_ckpt(
             The pipeline class to use. Pass `None` to determine automatically.
         local_files_only (`bool`, *optional*, defaults to `False`):
             Whether or not to only look at local files (i.e., do not try to download the model).
+        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+            An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
+            to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+            variant. If this parameter is `None`, the function loads a new instance of `CLIPTextModel`, if needed.
+        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+            An instance of
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+            to use. If this parameter is `None`, the function loads a new instance of `CLIPTokenizer`, if
+            needed.
     return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
@@ -1327,8 +1342,10 @@ def download_from_original_stable_diffusion_ckpt(
             feature_extractor=feature_extractor,
         )
     elif model_type == "FrozenCLIPEmbedder":
-        text_model = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        text_model = convert_ldm_clip_checkpoint(
+            checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
+        )
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer

         if load_safety_checker:
             safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
......
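
The lower-level entry point accepts the same arguments; when they are set, the `FrozenCLIPEmbedder` branch passes the encoder to `convert_ldm_clip_checkpoint` (which loads the checkpoint's CLIP weights into the provided instance) and reuses the tokenizer as-is. A sketch using the same placeholder names as above:

```python
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from transformers import CLIPTextModel, CLIPTokenizer

text_encoder = CLIPTextModel.from_pretrained("my-finetuned-clip")  # placeholder repo
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Neither CLIP component is re-downloaded: the provided objects are used directly.
pipe = download_from_original_stable_diffusion_ckpt(
    "./v1-5-pruned-emaonly.ckpt",  # placeholder local checkpoint path
    text_encoder=text_encoder,
    tokenizer=tokenizer,
)
```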