Unverified Commit d1bcbf38 authored by allo-, committed by GitHub

[textual_inversion] Add an option for only saving the embeddings (#781)

[textual_inversion] Add an option to only save embeddings

Add a command line option --only_save_embeds to the example script that skips
saving the full model. With this flag, only the learned embeddings are saved;
they can be added back to the original model at runtime in much the same way
they are created in the training script.
Saving the full model is still forced when --push_to_hub is used. (Implements #759)
parent df7cd5fe
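For context, loading the saved embeddings back into a pipeline at runtime might look roughly like the sketch below. This is not part of the commit; it assumes learned_embeds.bin holds a {placeholder_token: embedding_tensor} mapping as written by save_progress, and the base model id is only an example.

import torch
from diffusers import StableDiffusionPipeline

# Example base model id; use whatever checkpoint the embedding was trained against.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Assumed file format: {placeholder_token: embedding_tensor}.
learned_embeds = torch.load("learned_embeds.bin", map_location="cpu")

for token, embedding in learned_embeds.items():
    # Register the placeholder token and grow the text encoder's embedding matrix.
    if pipe.tokenizer.add_tokens(token) == 0:
        raise ValueError(f"Tokenizer already contains {token}; pick a different placeholder.")
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))

    # Copy the learned vector into the new token's row.
    token_id = pipe.tokenizer.convert_tokens_to_ids(token)
    weight = pipe.text_encoder.get_input_embeddings().weight
    weight.data[token_id] = embedding.to(weight.dtype)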
@@ -16,8 +16,9 @@ import PIL
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfFolder, Repository, whoami
 
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -25,7 +26,7 @@ from packaging import version
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
@@ -65,6 +66,12 @@ def parse_args():
         default=500,
         help="Save learned_embeds.bin every X updates steps.",
     )
+    parser.add_argument(
+        "--only_save_embeds",
+        action="store_true",
+        default=False,
+        help="Save only the embeddings for the new concept.",
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -596,16 +603,23 @@ def main():
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
-        pipeline = StableDiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            text_encoder=accelerator.unwrap_model(text_encoder),
-            tokenizer=tokenizer,
-            vae=vae,
-            unet=unet,
-            revision=args.revision,
-        )
-        pipeline.save_pretrained(args.output_dir)
-        # Also save the newly trained embeddings
+        if args.push_to_hub and args.only_save_embeds:
+            logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
+            save_full_model = True
+        else:
+            save_full_model = not args.only_save_embeds
+        if save_full_model:
+            pipeline = StableDiffusionPipeline(
+                text_encoder=accelerator.unwrap_model(text_encoder),
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
+                safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
+                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+            )
+            pipeline.save_pretrained(args.output_dir)
+        # Save the newly trained embeddings
         save_path = os.path.join(args.output_dir, "learned_embeds.bin")
         save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
...
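The embedding-only artifact comes from save_progress, whose implementation is not shown in this hunk. Based on its call signature above, it plausibly does something along these lines; this is an illustrative sketch only, not the script's actual code, and args.placeholder_token is an assumed attribute.

import torch

def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
    # Sketch: pull the learned row out of the token embedding matrix and
    # save it keyed by the placeholder token string.
    learned_embeds = (
        accelerator.unwrap_model(text_encoder)
        .get_input_embeddings()
        .weight[placeholder_token_id]
    )
    torch.save({args.placeholder_token: learned_embeds.detach().cpu()}, save_path)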