Unverified Commit 4b9f5895 authored by Yuta Hayashibe's avatar Yuta Hayashibe Committed by GitHub
Browse files

Add --pretrained_model_name_revision option to train_dreambooth.py (#933)

* Add --pretrained_model_name_revision option to train_dreambooth.py

* Renamed --pretrained_model_name_revision to --revision
parent e2243de5
...@@ -35,6 +35,13 @@ def parse_args(): ...@@ -35,6 +35,13 @@ def parse_args():
required=True, required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.", help="Path to pretrained model or model identifier from huggingface.co/models.",
) )
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument( parser.add_argument(
"--tokenizer_name", "--tokenizer_name",
type=str, type=str,
...@@ -344,7 +351,10 @@ def main(): ...@@ -344,7 +351,10 @@ def main():
if cur_class_images < args.num_class_images: if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
) )
pipeline.set_progress_bar_config(disable=True) pipeline.set_progress_bar_config(disable=True)
...@@ -390,14 +400,33 @@ def main(): ...@@ -390,14 +400,33 @@ def main():
# Load the tokenizer # Load the tokenizer
if args.tokenizer_name: if args.tokenizer_name:
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) tokenizer = CLIPTokenizer.from_pretrained(
args.tokenizer_name,
revision=args.revision,
)
elif args.pretrained_model_name_or_path: elif args.pretrained_model_name_or_path:
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
)
# Load models and create wrapper for stable diffusion # Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") text_encoder = CLIPTextModel.from_pretrained(
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") args.pretrained_model_name_or_path,
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
vae.requires_grad_(False) vae.requires_grad_(False)
if not args.train_text_encoder: if not args.train_text_encoder:
...@@ -613,6 +642,7 @@ def main(): ...@@ -613,6 +642,7 @@ def main():
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet), unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
) )
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment