Commit 372b5810 authored by Patrick von Platen

fix make style

parent 45171174
@@ -900,7 +900,12 @@ def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):
 def convert_open_clip_checkpoint(
-    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, local_files_only=False, **config_kwargs
+    checkpoint,
+    config_name,
+    prefix="cond_stage_model.model.",
+    has_projection=False,
+    local_files_only=False,
+    **config_kwargs,
 ):
     # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
     # text_model = CLIPTextModelWithProjection.from_pretrained(
@@ -989,13 +994,17 @@ def stable_unclip_image_encoder(original_config, local_files_only=False):
         if clip_model_name == "ViT-L/14":
             feature_extractor = CLIPImageProcessor()
-            image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                "openai/clip-vit-large-patch14", local_files_only=local_files_only
+            )
         else:
             raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
 
     elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
         feature_extractor = CLIPImageProcessor()
-        image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only)
+        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only
+        )
     else:
         raise NotImplementedError(
             f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
@@ -1178,8 +1187,7 @@ def download_from_original_stable_diffusion_ckpt(
             needed.
         config_files (`Dict[str, str]`, *optional*, defaults to `None`):
             A dictionary mapping from config file names to their contents. If this parameter is `None`, the function
-            will load the config files by itself, if needed.
-            Valid keys are:
+            will load the config files by itself, if needed. Valid keys are:
             - `v1`: Config file for Stable Diffusion v1
             - `v2`: Config file for Stable Diffusion v2
             - `xl`: Config file for Stable Diffusion XL
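As a usage sketch of the `config_files` argument documented in the hunk above (not from this commit): the import path and keyword names follow this file's public API as shown in the diff, the checkpoint and YAML paths are hypothetical, and the example assumes the values may be passed as local file paths.

# Sketch only: convert a local v1 checkpoint while supplying its YAML config
# through `config_files` (valid keys per the docstring: "v1", "v2", "xl").
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

pipe = download_from_original_stable_diffusion_ckpt(
    "./sd-v1-5.ckpt",  # hypothetical local checkpoint path
    config_files={"v1": "./v1-inference.yaml"},  # hypothetical config path
    local_files_only=True,  # skip Hub downloads; assumes tokenizers/encoders are cached locally
)
pipe.save_pretrained("./sd-v1-5-diffusers")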
@@ -1412,7 +1420,9 @@ def download_from_original_stable_diffusion_ckpt(
         config_kwargs = {"subfolder": "text_encoder"}
         text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
-        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only)
+        tokenizer = CLIPTokenizer.from_pretrained(
+            "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
+        )
 
         if stable_unclip is None:
             if controlnet:
@@ -1464,12 +1474,20 @@ def download_from_original_stable_diffusion_ckpt(
             elif stable_unclip == "txt2img":
                 if stable_unclip_prior is None or stable_unclip_prior == "karlo":
                     karlo_model = "kakaobrain/karlo-v1-alpha"
-                    prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior", local_files_only=local_files_only)
+                    prior = PriorTransformer.from_pretrained(
+                        karlo_model, subfolder="prior", local_files_only=local_files_only
+                    )
 
-                    prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
-                    prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+                    prior_tokenizer = CLIPTokenizer.from_pretrained(
+                        "openai/clip-vit-large-patch14", local_files_only=local_files_only
+                    )
+                    prior_text_model = CLIPTextModelWithProjection.from_pretrained(
+                        "openai/clip-vit-large-patch14", local_files_only=local_files_only
+                    )
 
-                    prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only)
+                    prior_scheduler = UnCLIPScheduler.from_pretrained(
+                        karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only
+                    )
                     prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
                 else:
                     raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}")
@@ -1496,7 +1514,9 @@ def download_from_original_stable_diffusion_ckpt(
     elif model_type == "PaintByExample":
         vision_model = convert_paint_by_example_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only)
+        feature_extractor = AutoFeatureExtractor.from_pretrained(
+            "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
+        )
         pipe = PaintByExamplePipeline(
             vae=vae,
             image_encoder=vision_model,
@@ -1509,11 +1529,19 @@ def download_from_original_stable_diffusion_ckpt(
         text_model = convert_ldm_clip_checkpoint(
             checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
         )
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) if tokenizer is None else tokenizer
+        tokenizer = (
+            CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+            if tokenizer is None
+            else tokenizer
+        )
 
         if load_safety_checker:
-            safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only)
-            feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only)
+            safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+                "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
+            )
+            feature_extractor = AutoFeatureExtractor.from_pretrained(
+                "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
+            )
         else:
             safety_checker = None
             feature_extractor = None
@@ -1541,9 +1569,13 @@ def download_from_original_stable_diffusion_ckpt(
         )
     elif model_type in ["SDXL", "SDXL-Refiner"]:
         if model_type == "SDXL":
-            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+            tokenizer = CLIPTokenizer.from_pretrained(
+                "openai/clip-vit-large-patch14", local_files_only=local_files_only
+            )
             text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
-            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only)
+            tokenizer_2 = CLIPTokenizer.from_pretrained(
+                "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
+            )
 
             config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
             config_kwargs = {"projection_dim": 1280}
@@ -1564,7 +1596,9 @@ def download_from_original_stable_diffusion_ckpt(
         else:
             tokenizer = None
             text_encoder = None
-            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only)
+            tokenizer_2 = CLIPTokenizer.from_pretrained(
+                "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
+            )
 
             config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
             config_kwargs = {"projection_dim": 1280}
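Every hunk in this diff is the same kind of line-length wrapping that `make style` applies, per the commit message above. A rough way to reproduce one of these rewrites with black's Python API is sketched below; it assumes black is installed and uses a 119-character line length matching the repository's style configuration.

# Sketch only: re-wrap one of the long lines from the hunk at -1412/+1420 with black.
import black

src = (
    "tokenizer = CLIPTokenizer.from_pretrained("
    '"stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only)\n'
)
print(black.format_str(src, mode=black.Mode(line_length=119)))
# Prints the wrapped form seen on the "+" side (modulo the surrounding indentation):
# tokenizer = CLIPTokenizer.from_pretrained(
#     "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
# )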