Unverified commit 6d2e19f7, authored by Patrick von Platen, committed by GitHub

[Examples] Allow downloading variant model files (#5531)



* add variant

* add variant

* Apply suggestions from code review

* reformat

* fix: textual_inversion.py

* fix: variant in model_info

---------
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
parent 2a7f43a7
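
For context: in diffusers, the `variant` argument to `from_pretrained` selects alternatively-named weight files in a repository, e.g. `variant="fp16"` resolves diffusion_pytorch_model.fp16.safetensors instead of the default diffusion_pytorch_model.safetensors. A minimal sketch of the behavior the new `--variant` flag passes through to the example scripts (the model id below is an assumption chosen for illustration, not part of this commit):

import torch
from diffusers import DiffusionPipeline

# Download and load the fp16 weight-file variant of the repository.
# The model id is illustrative only.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    variant="fp16",
    torch_dtype=torch.float16,
)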
@@ -86,6 +86,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
         controlnet=controlnet,
         safety_checker=None,
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
     )
     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -249,10 +250,13 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         required=False,
-        help=(
-            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
-            " float32 precision."
-        ),
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -767,11 +771,13 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     if args.controlnet_model_name_or_path:
...
@@ -74,6 +74,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
         unet=unet,
         controlnet=controlnet,
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
     )
     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -243,15 +244,18 @@ def parse_args(input_args=None):
         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
         " If not specified controlnet weights are initialized from unet.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--revision",
         type=str,
         default=None,
         required=False,
-        help=(
-            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
-            " float32 precision."
-        ),
+        help="Revision of pretrained model identifier from huggingface.co/models.",
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -793,10 +797,16 @@ def main(args):
     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        use_fast=False,
     )

     # import correct text encoder classes
@@ -810,10 +820,10 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder_one = text_encoder_cls_one.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
     text_encoder_two = text_encoder_cls_two.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
     )
     vae_path = (
         args.pretrained_model_name_or_path
@@ -824,9 +834,10 @@ def main(args):
         vae_path,
         subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
         revision=args.revision,
+        variant=args.variant,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     if args.controlnet_model_name_or_path:
...
@@ -332,6 +332,12 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
@@ -740,6 +746,7 @@ def main(args):
             torch_dtype=torch_dtype,
             safety_checker=None,
             revision=args.revision,
+            variant=args.variant,
         )
         pipeline.set_progress_bar_config(disable=True)
@@ -801,11 +808,13 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     # Adding a modifier token which is optimized ####
@@ -1229,6 +1238,7 @@ def main(args):
                 text_encoder=accelerator.unwrap_model(text_encoder),
                 tokenizer=tokenizer,
                 revision=args.revision,
+                variant=args.variant,
                 torch_dtype=weight_dtype,
             )
             pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -1278,7 +1288,7 @@ def main(args):
         # Final inference
         # Load previous pipeline
         pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+            args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
         )
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
         pipeline = pipeline.to(accelerator.device)
...
@@ -139,6 +139,7 @@ def log_validation(
         text_encoder=text_encoder,
         unet=accelerator.unwrap_model(unet),
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
         **pipeline_args,
     )
@@ -239,10 +240,13 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         required=False,
-        help=(
-            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
-            " float32 precision."
-        ),
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -859,6 +863,7 @@ def main(args):
             torch_dtype=torch_dtype,
             safety_checker=None,
             revision=args.revision,
+            variant=args.variant,
         )
         pipeline.set_progress_bar_config(disable=True)
@@ -912,18 +917,18 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )

     if model_has_vae(args):
         vae = AutoencoderKL.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
         )
     else:
         vae = None

     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1379,6 +1384,7 @@ def main(args):
             args.pretrained_model_name_or_path,
             unet=accelerator.unwrap_model(unet),
             revision=args.revision,
+            variant=args.variant,
             **pipeline_args,
         )
...
@@ -460,7 +460,10 @@ def main():
     # Load models and create wrapper for stable diffusion
     text_encoder = FlaxCLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
+        args.pretrained_model_name_or_path,
+        subfolder="text_encoder",
+        dtype=weight_dtype,
+        revision=args.revision,
     )
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
         vae_arg,
@@ -468,7 +471,10 @@ def main():
         **vae_kwargs,
     )
     unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
+        args.pretrained_model_name_or_path,
+        subfolder="unet",
+        dtype=weight_dtype,
+        revision=args.revision,
     )

     # Optimization
...
@@ -183,6 +183,12 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
@@ -750,6 +756,7 @@ def main(args):
             torch_dtype=torch_dtype,
             safety_checker=None,
             revision=args.revision,
+            variant=args.variant,
         )
         pipeline.set_progress_bar_config(disable=True)
@@ -803,11 +810,11 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )

     try:
         vae = AutoencoderKL.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
         )
     except OSError:
         # IF does not have a VAE so let's just set it to None
@@ -815,7 +822,7 @@ def main(args):
         vae = None

     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     # We only train the additional adapter LoRA layers
@@ -1310,6 +1317,7 @@ def main(args):
                 unet=accelerator.unwrap_model(unet),
                 text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
                 revision=args.revision,
+                variant=args.variant,
                 torch_dtype=weight_dtype,
             )
@@ -1395,7 +1403,7 @@ def main(args):
         # Final inference
         # Load previous pipeline
         pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+            args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
         )

         # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
...
@@ -204,6 +204,12 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -877,6 +883,7 @@ def main(args):
             args.pretrained_model_name_or_path,
             torch_dtype=torch_dtype,
             revision=args.revision,
+            variant=args.variant,
         )
         pipeline.set_progress_bar_config(disable=True)
@@ -915,10 +922,16 @@ def main(args):
     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        use_fast=False,
     )

     # import correct text encoder classes
@@ -932,10 +945,10 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder_one = text_encoder_cls_one.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
     text_encoder_two = text_encoder_cls_two.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
     )
     vae_path = (
         args.pretrained_model_name_or_path
@@ -943,10 +956,13 @@ def main(args):
         else args.pretrained_vae_model_name_or_path
     )
     vae = AutoencoderKL.from_pretrained(
-        vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+        vae_path,
+        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+        revision=args.revision,
+        variant=args.variant,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     # We only train the additional adapter LoRA layers
@@ -1571,10 +1587,16 @@ def main(args):
                 # create pipeline
                 if not args.train_text_encoder:
                     text_encoder_one = text_encoder_cls_one.from_pretrained(
-                        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+                        args.pretrained_model_name_or_path,
+                        subfolder="text_encoder",
+                        revision=args.revision,
+                        variant=args.variant,
                     )
                     text_encoder_two = text_encoder_cls_two.from_pretrained(
-                        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+                        args.pretrained_model_name_or_path,
+                        subfolder="text_encoder_2",
+                        revision=args.revision,
+                        variant=args.variant,
                     )
                 pipeline = StableDiffusionXLPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
@@ -1583,6 +1605,7 @@ def main(args):
                     text_encoder_2=accelerator.unwrap_model(text_encoder_two),
                     unet=accelerator.unwrap_model(unet),
                     revision=args.revision,
+                    variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
@@ -1660,10 +1683,15 @@ def main(args):
            vae_path,
            subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
            revision=args.revision,
+           variant=args.variant,
            torch_dtype=weight_dtype,
        )
        pipeline = StableDiffusionXLPipeline.from_pretrained(
-           args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+           args.pretrained_model_name_or_path,
+           vae=vae,
+           revision=args.revision,
+           variant=args.variant,
+           torch_dtype=weight_dtype,
        )

        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
...
@@ -78,6 +78,12 @@ def parse_args():
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -435,9 +441,11 @@ def main():
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+    )
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
@@ -915,6 +923,7 @@ def main():
                 text_encoder=accelerator.unwrap_model(text_encoder),
                 vae=accelerator.unwrap_model(vae),
                 revision=args.revision,
+                variant=args.variant,
                 torch_dtype=weight_dtype,
             )
             pipeline = pipeline.to(accelerator.device)
@@ -966,6 +975,7 @@ def main():
             vae=accelerator.unwrap_model(vae),
             unet=unet,
             revision=args.revision,
+            variant=args.variant,
         )
         pipeline.save_pretrained(args.output_dir)
...
@@ -118,6 +118,12 @@ def parse_args():
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -484,9 +490,10 @@ def main():
         vae_path,
         subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
         revision=args.revision,
+        variant=args.variant,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
@@ -695,10 +702,16 @@ def main():
     # Load scheduler, tokenizer and models.
     tokenizer_1 = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
     )
     tokenizer_2 = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        use_fast=False,
     )
     text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
     text_encoder_cls_2 = import_model_class_from_model_name_or_path(
@@ -708,10 +721,10 @@ def main():
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder_1 = text_encoder_cls_1.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
     text_encoder_2 = text_encoder_cls_2.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
     )

     # We ALWAYS pre-compute the additional condition embeddings needed for SDXL
@@ -1109,6 +1122,7 @@ def main():
                 tokenizer_2=tokenizer_2,
                 vae=vae,
                 revision=args.revision,
+                variant=args.variant,
                 torch_dtype=weight_dtype,
             )
             pipeline = pipeline.to(accelerator.device)
@@ -1176,6 +1190,7 @@ def main():
             vae=vae,
             unet=unet,
             revision=args.revision,
+            variant=args.variant,
         )
         pipeline.save_pretrained(args.output_dir)
...
@@ -85,6 +85,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
         unet=unet,
         adapter=adapter,
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
     )
     pipeline = pipeline.to(accelerator.device)
@@ -262,6 +263,12 @@ def parse_args(input_args=None):
             " float32 precision."
         ),
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
@@ -812,10 +819,16 @@ def main(args):
     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        use_fast=False,
     )

     # import correct text encoder classes
@@ -829,10 +842,10 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder_one = text_encoder_cls_one.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
     text_encoder_two = text_encoder_cls_two.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
     )
     vae_path = (
         args.pretrained_model_name_or_path
@@ -843,9 +856,10 @@ def main(args):
         vae_path,
         subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
         revision=args.revision,
+        variant=args.variant,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     if args.adapter_model_name_or_path:
...
@@ -148,6 +148,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
         unet=accelerator.unwrap_model(unet),
         safety_checker=None,
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
     )
     pipeline = pipeline.to(accelerator.device)
@@ -209,6 +210,12 @@ def parse_args():
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -567,10 +574,10 @@ def main():
     # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
     with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
         text_encoder = CLIPTextModel.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
         )
         vae = AutoencoderKL.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
         )

     unet = UNet2DConditionModel.from_pretrained(
@@ -585,7 +592,7 @@ def main():
     # Create EMA for the unet.
     if args.use_ema:
         ema_unet = UNet2DConditionModel.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
@@ -1026,6 +1033,7 @@ def main():
             vae=vae,
             unet=unet,
             revision=args.revision,
+            variant=args.variant,
         )
         pipeline.save_pretrained(args.output_dir)
...
@@ -54,6 +54,12 @@ def parse_args():
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
...
@@ -130,6 +130,12 @@ def parse_args():
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -454,9 +460,11 @@ def main():
     text_encoder = CLIPTextModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+    )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
     # freeze parameters of models to save more memory
     unet.requires_grad_(False)
@@ -881,6 +889,7 @@ def main():
                     args.pretrained_model_name_or_path,
                     unet=accelerator.unwrap_model(unet),
                     revision=args.revision,
+                    variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
                 pipeline = pipeline.to(accelerator.device)
@@ -937,7 +946,7 @@ def main():
     # Final inference
     # Load previous pipeline
     pipeline = DiffusionPipeline.from_pretrained(
-        args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
+        args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
     )
     pipeline = pipeline.to(accelerator.device)
...
@@ -180,6 +180,12 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -570,10 +576,16 @@ def main(args):
     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        use_fast=False,
     )

     # import correct text encoder classes
@@ -587,10 +599,10 @@ def main(args):
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     text_encoder_one = text_encoder_cls_one.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
     text_encoder_two = text_encoder_cls_two.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
     )
     vae_path = (
         args.pretrained_model_name_or_path
@@ -598,10 +610,13 @@ def main(args):
         else args.pretrained_vae_model_name_or_path
     )
     vae = AutoencoderKL.from_pretrained(
-        vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+        vae_path,
+        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+        revision=args.revision,
+        variant=args.variant,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     # We only train the additional adapter LoRA layers
@@ -1176,6 +1191,7 @@ def main(args):
                 text_encoder_2=accelerator.unwrap_model(text_encoder_two),
                 unet=accelerator.unwrap_model(unet),
                 revision=args.revision,
+                variant=args.variant,
                 torch_dtype=weight_dtype,
             )
@@ -1241,7 +1257,11 @@ def main(args):
         # Final inference
         # Load previous pipeline
         pipeline = StableDiffusionXLPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+            args.pretrained_model_name_or_path,
+            vae=vae,
+            revision=args.revision,
+            variant=args.variant,
+            torch_dtype=weight_dtype,
         )
         pipeline = pipeline.to(accelerator.device)
...
@@ -148,6 +148,12 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--dataset_name",
         type=str,
@@ -618,10 +624,16 @@ def main(args):
     # Load the tokenizers
     tokenizer_one = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+        use_fast=False,
     )
     tokenizer_two = AutoTokenizer.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+        use_fast=False,
     )

     # import correct text encoder classes
@@ -636,10 +648,10 @@ def main(args):
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     # Check for terminal SNR in combination with SNR Gamma
     text_encoder_one = text_encoder_cls_one.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
     text_encoder_two = text_encoder_cls_two.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
     )
     vae_path = (
         args.pretrained_model_name_or_path
@@ -647,10 +659,13 @@ def main(args):
         else args.pretrained_vae_model_name_or_path
     )
     vae = AutoencoderKL.from_pretrained(
-        vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+        vae_path,
+        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+        revision=args.revision,
+        variant=args.variant,
     )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     # Freeze vae and text encoders.
@@ -677,7 +692,7 @@ def main(args):
     # Create EMA for the unet.
     if args.use_ema:
         ema_unet = UNet2DConditionModel.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
@@ -1145,12 +1160,14 @@ def main(args):
                 vae_path,
                 subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
                 revision=args.revision,
+                variant=args.variant,
             )
             pipeline = StableDiffusionXLPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 vae=vae,
                 unet=accelerator.unwrap_model(unet),
                 revision=args.revision,
+                variant=args.variant,
                 torch_dtype=weight_dtype,
             )
             if args.prediction_type is not None:
@@ -1198,10 +1215,16 @@ def main(args):
             vae_path,
             subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
             revision=args.revision,
+            variant=args.variant,
             torch_dtype=weight_dtype,
         )
         pipeline = StableDiffusionXLPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+            args.pretrained_model_name_or_path,
+            unet=unet,
+            vae=vae,
+            revision=args.revision,
+            variant=args.variant,
+            torch_dtype=weight_dtype,
         )

         if args.prediction_type is not None:
             scheduler_args = {"prediction_type": args.prediction_type}
...
@@ -126,6 +126,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
         vae=vae,
         safety_checker=None,
         revision=args.revision,
+        variant=args.variant,
         torch_dtype=weight_dtype,
     )
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -206,6 +207,12 @@ def parse_args():
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
@@ -624,9 +631,11 @@ def main():
     text_encoder = CLIPTextModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+    )
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

     # Add the placeholder token in tokenizer
...
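
Since every touched script wires `--variant` straight into its `from_pretrained` calls, usage is uniform across the examples. A hypothetical invocation (script name, model id, and output path are illustrative only):

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --variant="fp16" \
  --output_dir="sd-model-finetuned"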