Unverified Commit 15001306 authored by Wadim Korablin's avatar Wadim Korablin Committed by GitHub
Browse files

Support for manual CLIP loading in StableDiffusionPipeline - txt2img. (#3832)



* Support for manual CLIP loading in StableDiffusionPipeline - txt2img.

* Update src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

* Update variables & according docs to match previous style.

* Updated to match style & quality of 'diffusers'

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 219636f7
......@@ -1339,6 +1339,17 @@ class FromCkptMixin:
"ddim"]`.
load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not.
text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
An instance of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
needed.
tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
An instance of
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
itself, if needed.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
......@@ -1383,6 +1394,8 @@ class FromCkptMixin:
upcast_attention = kwargs.pop("upcast_attention", None)
load_safety_checker = kwargs.pop("load_safety_checker", True)
prediction_type = kwargs.pop("prediction_type", None)
text_encoder = kwargs.pop("text_encoder", None)
tokenizer = kwargs.pop("tokenizer", None)
torch_dtype = kwargs.pop("torch_dtype", None)
......@@ -1463,6 +1476,8 @@ class FromCkptMixin:
upcast_attention=upcast_attention,
load_safety_checker=load_safety_checker,
prediction_type=prediction_type,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
if torch_dtype is not None:
......
......@@ -734,8 +734,12 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
return hf_model
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
text_model = (
CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
if text_encoder is None
else text_encoder
)
keys = list(checkpoint.keys())
......@@ -1025,6 +1029,8 @@ def download_from_original_stable_diffusion_ckpt(
load_safety_checker: bool = True,
pipeline_class: DiffusionPipeline = None,
local_files_only=False,
text_encoder=None,
tokenizer=None,
) -> DiffusionPipeline:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
......@@ -1072,6 +1078,15 @@ def download_from_original_stable_diffusion_ckpt(
The pipeline class to use. Pass `None` to determine automatically.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
An instance of
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
needed.
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""
......@@ -1327,8 +1342,10 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor=feature_extractor,
)
elif model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = convert_ldm_clip_checkpoint(
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer
if load_safety_checker:
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment