Unverified Commit e185084a authored by Levi McCallum, committed by GitHub

Add variant argument to dreambooth lora sdxl advanced (#6021)


Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent b2172922
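
For context: in diffusers, the `variant` argument to `from_pretrained` selects which serialized weight files are downloaded (e.g. `diffusion_pytorch_model.fp16.safetensors` instead of `diffusion_pytorch_model.safetensors`), while `torch_dtype` controls the in-memory precision after loading. A minimal sketch of what the new flag forwards to, using the SDXL base checkpoint, which publishes fp16 variant files:

import torch
from diffusers import StableDiffusionXLPipeline

# variant="fp16" fetches the half-precision weight files where available,
# roughly halving the download; torch_dtype casts the loaded modules.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)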
@@ -225,6 +225,12 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -1064,6 +1070,7 @@ def main(args):
             args.pretrained_model_name_or_path,
             torch_dtype=torch_dtype,
             revision=args.revision,
+            variant=args.variant,
         )
         pipeline.set_progress_bar_config(disable=True)
@@ -1102,10 +1109,18 @@ def main(args):
     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        variant=args.variant,
+        use_fast=False,
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        variant=args.variant,
+        use_fast=False,
     )

     # import correct text encoder classes
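Note: tokenizer repositories hold plain-text assets (vocabulary, merges, config) with no precision-specific variants, so `variant` does not change which tokenizer files are fetched here; it is passed, presumably, to keep the call sites uniform.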
@@ -1119,10 +1134,10 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder_one = text_encoder_cls_one.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
     text_encoder_two = text_encoder_cls_two.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
     )
     vae_path = (
         args.pretrained_model_name_or_path
@@ -1130,10 +1145,13 @@ def main(args):
         else args.pretrained_vae_model_name_or_path
     )
     vae = AutoencoderKL.from_pretrained(
-        vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+        vae_path,
+        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+        revision=args.revision,
+        variant=args.variant,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     if args.train_text_encoder_ti:
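One caveat when combining this with `--pretrained_vae_model_name_or_path`: the requested variant must actually be published in whichever repository each component is loaded from, so a variant that exists for the base model but not for a standalone VAE repository can make that load fail.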
@@ -1843,10 +1861,16 @@ def main(args):
                 # create pipeline
                 if freeze_text_encoder:
                     text_encoder_one = text_encoder_cls_one.from_pretrained(
-                        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+                        args.pretrained_model_name_or_path,
+                        subfolder="text_encoder",
+                        revision=args.revision,
+                        variant=args.variant,
                     )
                     text_encoder_two = text_encoder_cls_two.from_pretrained(
-                        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+                        args.pretrained_model_name_or_path,
+                        subfolder="text_encoder_2",
+                        revision=args.revision,
+                        variant=args.variant,
                     )
                 pipeline = StableDiffusionXLPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
@@ -1855,6 +1879,7 @@ def main(args):
                     text_encoder_2=accelerator.unwrap_model(text_encoder_two),
                     unet=accelerator.unwrap_model(unet),
                     revision=args.revision,
+                    variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
@@ -1932,10 +1957,15 @@ def main(args):
                 vae_path,
                 subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
                 revision=args.revision,
+                variant=args.variant,
                 torch_dtype=weight_dtype,
             )
             pipeline = StableDiffusionXLPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+                args.pretrained_model_name_or_path,
+                vae=vae,
+                revision=args.revision,
+                variant=args.variant,
+                torch_dtype=weight_dtype,
             )
             # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
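
Taken together, the commit threads one optional CLI flag through every loader call in the script. A minimal, self-contained sketch of the pattern (a stand-alone illustration, not the training script itself):

import argparse

from diffusers import UNet2DConditionModel

parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
parser.add_argument("--revision", type=str, default=None)
parser.add_argument("--variant", type=str, default=None)
args = parser.parse_args(
    ["--pretrained_model_name_or_path", "stabilityai/stable-diffusion-xl-base-1.0", "--variant", "fp16"]
)

# default=None keeps the old behavior: from_pretrained treats variant=None
# as "load the ordinary, unsuffixed weight files".
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet",
    revision=args.revision,
    variant=args.variant,
)

In the script itself this corresponds to passing `--variant fp16` on the command line.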