"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ba21735c42734b3225a612128d304cf3735a32f1"
Unverified Commit f2acfb67 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Remove hardcoded names from PT scripts (#1778)



* Remove hardcoded names from PT scripts

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 8aa4372a
...@@ -16,9 +16,8 @@ import PIL ...@@ -16,9 +16,8 @@ import PIL
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, whoami
...@@ -28,7 +27,7 @@ from packaging import version ...@@ -28,7 +27,7 @@ from packaging import version
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
...@@ -678,14 +677,12 @@ def main(): ...@@ -678,14 +677,12 @@ def main():
else: else:
save_full_model = not args.only_save_embeds save_full_model = not args.only_save_embeds
if save_full_model: if save_full_model:
pipeline = StableDiffusionPipeline( pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=accelerator.unwrap_model(text_encoder),
vae=vae, vae=vae,
unet=unet, unet=unet,
tokenizer=tokenizer, tokenizer=tokenizer,
scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
) )
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
# Save the newly trained embeddings # Save the newly trained embeddings
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment