Unverified commit edd78804, authored by Linoy Tsaban and committed by GitHub

[HiDream LoRA] optimizations + small updates (#11381)



* 1. add pre-computation of prompt embeddings when custom prompts are used as well
2. save the model card even if the model is not pushed to the Hub
3. remove scheduler initialization from the code example - not necessary anymore (it's now in the base model's config)
4. add skip_final_inference - allows running with validation but skipping the final loading of the pipeline with the LoRA weights, to reduce memory requirements

* pre encode validation prompt as well

* Update examples/dreambooth/train_dreambooth_lora_hidream.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* pre encode validation prompt as well

* Apply style fixes

* empty commit

* change default trained modules

* empty commit

* address comments + change encoding of validation prompt (before, it was only pre-encoded when custom prompts were provided, but it should be pre-encoded either way)

* Apply style fixes

* empty commit

* fix validation_embeddings definition

* fix final inference condition

* fix pipeline deletion in last inference

* Apply style fixes

* empty commit

* layers

* remove README remarks about pre-computing only when an instance prompt is provided, and change the example to 3d icons

* smol fix

* empty commit

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent b4be4228
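
Taken together, points 1 and 4 of the commit message amount to one pattern: encode each prompt once, keep only the resulting tensors, and keep the large text encoders off the accelerator. Below is a minimal, hedged sketch of that pattern outside the training script (the checkpoint id and the `encode_once` helper are illustrative; depending on the checkpoint you may need to pass `tokenizer_4`/`text_encoder_4` explicitly, as in the README snippet further down):

```python
import torch

from diffusers import HiDreamImagePipeline
from diffusers.training_utils import free_memory

# Load the pipeline only as a text-encoding helper; transformer and VAE are skipped.
text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev", transformer=None, vae=None, torch_dtype=torch.bfloat16
)


def encode_once(prompt: str, device: str = "cuda") -> dict:
    # Bring the encoders up only for the duration of the call, then offload again
    # (see https://github.com/huggingface/diffusers/issues/11376 for why this matters).
    text_encoding_pipeline.to(device)
    with torch.no_grad():
        t5, neg_t5, llama3, neg_llama3, pooled, neg_pooled = text_encoding_pipeline.encode_prompt(prompt=prompt)
    text_encoding_pipeline.to("cpu")
    return {
        "prompt_embeds_t5": t5,
        "prompt_embeds_llama3": llama3,
        "pooled_prompt_embeds": pooled,
        "negative_prompt_embeds_t5": neg_t5,
        "negative_prompt_embeds_llama3": neg_llama3,
        "negative_pooled_prompt_embeds": neg_pooled,
    }


validation_embeddings = encode_once("a 3dicon, a llama eating ramen")
del text_encoding_pipeline
free_memory()  # gc + empty the accelerator cache before training proper starts
```

The cached dictionary is exactly the shape that the reworked `log_validation` in the diff below consumes as `pipeline_args`.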
@@ -51,22 +51,9 @@ When running `accelerate config`, if we specify torch compile mode to True there
 Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
-### Dog toy example
+### 3d icon example
-Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
+For this example we will use some 3d icon images: https://huggingface.co/datasets/linoyts/3d_icon.
-Let's first download it locally:
-```python
-from huggingface_hub import snapshot_download
-local_dir = "./dog"
-snapshot_download(
-    "diffusers/dog-example",
-    local_dir=local_dir, repo_type="dataset",
-    ignore_patterns=".gitattributes",
-)
-```
 This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
@@ -74,31 +61,31 @@ Now, we can launch training using:
 > [!NOTE]
 > The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
 > 8-bit Adam optimizer, latent caching, offloading, no validation.
-> Additionally, when provided with 'instance_prompt' only and no 'caption_column' (used for custom prompts for each image)
-> text embeddings are pre-computed to save memory.
+> all text embeddings are pre-computed to save memory.
 ```bash
 export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
-export INSTANCE_DIR="dog"
+export INSTANCE_DIR="linoyts/3d_icon"
 export OUTPUT_DIR="trained-hidream-lora"
 accelerate launch train_dreambooth_lora_hidream.py \
   --pretrained_model_name_or_path=$MODEL_NAME \
-  --instance_data_dir=$INSTANCE_DIR \
+  --dataset_name=$INSTANCE_DIR \
   --output_dir=$OUTPUT_DIR \
   --mixed_precision="bf16" \
-  --instance_prompt="a photo of sks dog" \
+  --instance_prompt="3d icon" \
+  --caption_column="prompt" \
+  --validation_prompt="a 3dicon, a llama eating ramen" \
   --resolution=1024 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
   --use_8bit_adam \
-  --rank=16 \
+  --rank=8 \
   --learning_rate=2e-4 \
   --report_to="wandb" \
-  --lr_scheduler="constant" \
+  --lr_scheduler="constant_with_warmup" \
-  --lr_warmup_steps=0 \
+  --lr_warmup_steps=100 \
   --max_train_steps=1000 \
   --cache_latents \
   --gradient_checkpointing \
   --validation_epochs=25 \
   --seed="0" \
@@ -128,6 +115,5 @@ We provide several options for optimizing memory optimization:
 * `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
 * `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
 * `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
-* `--instance_prompt` and no `--caption_column`: when only an instance prompt is provided, we will pre-compute the text embeddings and remove the text encoders from memory once done.
 Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
@@ -120,11 +120,7 @@ You should use `{instance_prompt}` to trigger the image generation.
 ```py
 >>> import torch
 >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
->>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
->>> scheduler = UniPCMultistepScheduler(
-...     flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
-... )
+>>> from diffusers import HiDreamImagePipeline
 >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
 >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
@@ -136,7 +132,6 @@ You should use `{instance_prompt}` to trigger the image generation.
 >>> pipe = HiDreamImagePipeline.from_pretrained(
 ...     "HiDream-ai/HiDream-I1-Full",
-...     scheduler=scheduler,
 ...     tokenizer_4=tokenizer_4,
 ...     text_encoder_4=text_encoder_4,
 ...     torch_dtype=torch.bfloat16,
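
The model-card snippet above stops once the pipeline is built; here is a short, hedged continuation showing how the trained LoRA would typically be applied (the output directory matches the README command above, the prompt and filename are illustrative, and this is not part of the diff):

```py
>>> pipe.load_lora_weights("trained-hidream-lora")  # --output_dir from the README command
>>> pipe.enable_model_cpu_offload()  # optional: trade speed for lower VRAM
>>> image = pipe("3d icon, a llama eating ramen").images[0]
>>> image.save("3d_icon_llama.png")
```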
@@ -201,6 +196,7 @@ def log_validation(
     torch_dtype,
     is_final_validation=False,
 ):
+    args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
     logger.info(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
@@ -212,28 +208,16 @@ def log_validation(
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
     autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
-    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
-    with torch.no_grad():
-        (
-            prompt_embeds_t5,
-            negative_prompt_embeds_t5,
-            prompt_embeds_llama3,
-            negative_prompt_embeds_llama3,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = pipeline.encode_prompt(
-            pipeline_args["prompt"],
-        )
     images = []
     for _ in range(args.num_validation_images):
         with autocast_ctx:
             image = pipeline(
-                prompt_embeds_t5=prompt_embeds_t5,
+                prompt_embeds_t5=pipeline_args["prompt_embeds_t5"],
-                prompt_embeds_llama3=prompt_embeds_llama3,
+                prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"],
-                negative_prompt_embeds_t5=negative_prompt_embeds_t5,
+                negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"],
-                negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
+                negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"],
-                pooled_prompt_embeds=pooled_prompt_embeds,
+                pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"],
-                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"],
                 generator=generator,
             ).images[0]
             images.append(image)
@@ -252,9 +236,9 @@ def log_validation(
             }
         )
+    pipeline.to("cpu")
     del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    free_memory()
     return images
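
The `pipeline.to("cpu")` added before `del pipeline` follows the observation in https://github.com/huggingface/diffusers/issues/11376 that dropping the last reference to a module still living on the accelerator did not reliably return its memory in this script. A self-contained sketch of the same cleanup, using a stand-in module (`free_memory` is the helper from `diffusers.training_utils`):

```python
import torch
from diffusers.training_utils import free_memory

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4096, 4096).to(device)  # stand-in for the validation pipeline

# ... run inference with `model` ...

model.to("cpu")  # move weights off the accelerator first
del model        # then drop the reference
free_memory()    # roughly: gc.collect() plus emptying the CUDA/XPU/MPS cache
```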
@@ -392,6 +376,14 @@ def parse_args(input_args=None):
         default=None,
         help="A prompt that is used during validation to verify that the model is learning.",
     )
+    parser.add_argument(
+        "--skip_final_inference",
+        default=False,
+        action="store_true",
+        help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.",
+    )
     parser.add_argument(
         "--final_validation_prompt",
         type=str,
@@ -1016,6 +1008,7 @@ def main(args):
                     image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                     image.save(image_filename)
+        pipeline.to("cpu")
         del pipeline
         free_memory()
@@ -1140,7 +1133,7 @@ def main(args):
     if args.lora_layers is not None:
         target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
     else:
-        target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+        target_modules = ["to_k", "to_q", "to_v", "to_out"]
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
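
The new default drops the `.0` suffix, presumably because the HiDream attention blocks register their output projection as `to_out` rather than `to_out.0`. A hedged sketch of the resulting PEFT config with the rank from the updated README command (`lora_alpha` and the init scheme are assumptions about the script's defaults, not taken from this diff):

```python
from peft import LoraConfig

transformer_lora_config = LoraConfig(
    r=8,  # --rank=8 in the updated README command
    lora_alpha=8,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out"],  # new defaults from this PR
)
# As in the other DreamBooth LoRA scripts, the config is then attached with
# transformer.add_adapter(transformer_lora_config).
```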
@@ -1314,42 +1307,65 @@ def main(args):
         )
     def compute_text_embeddings(prompt, text_encoding_pipeline):
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
         with torch.no_grad():
-            t5_prompt_embeds, _, llama3_prompt_embeds, _, pooled_prompt_embeds, _ = (
-                text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length)
+            (
+                t5_prompt_embeds,
+                negative_prompt_embeds_t5,
+                llama3_prompt_embeds,
+                negative_prompt_embeds_llama3,
+                pooled_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            ) = text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length)
+            return (
+                t5_prompt_embeds,
+                llama3_prompt_embeds,
+                pooled_prompt_embeds,
+                negative_prompt_embeds_t5,
+                negative_prompt_embeds_llama3,
+                negative_pooled_prompt_embeds,
             )
-        if args.offload:  # back to cpu
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
-        return t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
+        if args.offload:
+            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
         (
             instance_prompt_hidden_states_t5,
             instance_prompt_hidden_states_llama3,
             instance_pooled_prompt_embeds,
+            _,
+            _,
+            _,
         ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
+        if args.offload:
+            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        (
-            class_prompt_hidden_states_t5,
-            class_prompt_hidden_states_llama3,
-            class_pooled_prompt_embeds,
-        ) = compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+        if args.offload:
+            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
+            compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+        )
+        if args.offload:
+            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
-    # Clear the memory here
-    if not train_dataset.custom_instance_prompts:
-        # delete tokenizers and text ecnoders except for llama (tokenizer & te four)
-        # as it's needed for inference with pipeline
-        del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two, tokenizer_three
-        if not args.validation_prompt:
-            del tokenizer_four, text_encoder_four
-        free_memory()
+    validation_embeddings = {}
+    if args.validation_prompt is not None:
+        if args.offload:
+            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+        (
+            validation_embeddings["prompt_embeds_t5"],
+            validation_embeddings["prompt_embeds_llama3"],
+            validation_embeddings["pooled_prompt_embeds"],
+            validation_embeddings["negative_prompt_embeds_t5"],
+            validation_embeddings["negative_prompt_embeds_llama3"],
+            validation_embeddings["negative_pooled_prompt_embeds"],
+        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
+        if args.offload:
+            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
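
The offload dance above (move the text-encoding pipeline to the accelerator, encode, move it back) now repeats for the instance, class, and validation prompts. A hedged sketch of how the same pattern could be expressed once as a context manager; the names are illustrative and this helper is not part of the PR:

```python
from contextlib import contextmanager


@contextmanager
def on_device(pipeline, device, enabled=True):
    """Temporarily move an encoding pipeline to `device`, returning it to CPU afterwards."""
    if enabled:
        pipeline.to(device)
    try:
        yield pipeline
    finally:
        if enabled:
            pipeline.to("cpu")


# Usage mirroring the hunk above (illustrative):
# with on_device(text_encoding_pipeline, accelerator.device, enabled=args.offload) as pipe:
#     embeds = compute_text_embeddings(args.validation_prompt, pipe)
```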
@@ -1367,19 +1383,51 @@ def main(args):
     vae_config_scaling_factor = vae.config.scaling_factor
     vae_config_shift_factor = vae.config.shift_factor
-    if args.cache_latents:
+    # if cache_latents is set to True, we encode images to latents and store them.
+    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
+    # we encode them in advance as well.
+    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    if precompute_latents:
+        t5_prompt_cache = []
+        llama3_prompt_cache = []
+        pooled_prompt_cache = []
         latents_cache = []
         if args.offload:
             vae = vae.to(accelerator.device)
         for batch in tqdm(train_dataloader, desc="Caching latents"):
             with torch.no_grad():
+                if args.cache_latents:
                     batch["pixel_values"] = batch["pixel_values"].to(
                         accelerator.device, non_blocking=True, dtype=vae.dtype
                     )
                     latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                if train_dataset.custom_instance_prompts:
+                    text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+                    t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds, _, _, _ = compute_text_embeddings(
+                        batch["prompts"], text_encoding_pipeline
+                    )
+                    t5_prompt_cache.append(t5_prompt_embeds)
+                    llama3_prompt_cache.append(llama3_prompt_embeds)
+                    pooled_prompt_cache.append(pooled_prompt_embeds)
-        if args.validation_prompt is None:
+    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+    if args.offload or args.cache_latents:
+        vae = vae.to("cpu")
+    if args.cache_latents:
         del vae
+    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+    del (
+        text_encoder_one,
+        text_encoder_two,
+        text_encoder_three,
+        text_encoder_four,
+        tokenizer_two,
+        tokenizer_three,
+        tokenizer_four,
+        text_encoding_pipeline,
+    )
     free_memory()
     # Scheduler and math around the number of training steps.
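
For context on how the two caches built above are consumed: the stored `latent_dist` objects are sampled per step and normalized with the VAE's shift and scaling factors kept in `vae_config_shift_factor` / `vae_config_scaling_factor`, while the prompt caches are indexed by `step` (next hunk). A hedged sketch of the latent side, mirroring the convention used across the diffusers DreamBooth scripts; the helper name and default dtype are illustrative:

```python
import torch


def model_input_from_cache(latents_cache, step, shift_factor, scaling_factor, weight_dtype=torch.bfloat16):
    # Each cache entry is the DiagonalGaussianDistribution returned by vae.encode(...).latent_dist,
    # so sampling here preserves the encoder's stochasticity without keeping the VAE in memory.
    latents = latents_cache[step].sample()
    return ((latents - shift_factor) * scaling_factor).to(dtype=weight_dtype)
```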
@@ -1487,9 +1535,9 @@ def main(args):
             with accelerator.accumulate(models_to_accumulate):
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
-                    t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
-                        prompts, text_encoding_pipeline
-                    )
+                    t5_prompt_embeds = t5_prompt_cache[step]
+                    llama3_prompt_embeds = llama3_prompt_cache[step]
+                    pooled_prompt_embeds = pooled_prompt_cache[step]
                 else:
                     t5_prompt_embeds = t5_prompt_embeds.repeat(len(prompts), 1, 1)
                     llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, len(prompts), 1, 1)
@@ -1619,26 +1667,30 @@ def main(args):
                    # create pipeline
                    pipeline = HiDreamImagePipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
-                        tokenizer_4=tokenizer_four,
-                        text_encoder_4=text_encoder_four,
+                        tokenizer=None,
+                        text_encoder=None,
+                        tokenizer_2=None,
+                        text_encoder_2=None,
+                        tokenizer_3=None,
+                        text_encoder_3=None,
+                        tokenizer_4=None,
+                        text_encoder_4=None,
                        transformer=accelerator.unwrap_model(transformer),
                        revision=args.revision,
                        variant=args.variant,
                        torch_dtype=weight_dtype,
                    )
-                    pipeline_args = {"prompt": args.validation_prompt}
                    images = log_validation(
                        pipeline=pipeline,
                        args=args,
                        accelerator=accelerator,
-                        pipeline_args=pipeline_args,
+                        pipeline_args=validation_embeddings,
                        torch_dtype=weight_dtype,
                        epoch=epoch,
                    )
-                    free_memory()
-                    images = None
                    del pipeline
+                    images = None
+                    free_memory()
        # Save the lora layers
        accelerator.wait_for_everyone()
@@ -1655,20 +1707,22 @@ def main(args):
            transformer_lora_layers=transformer_lora_layers,
        )
+        images = []
+        run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)
+        should_run_final_inference = not args.skip_final_inference and run_validation
+        if should_run_final_inference:
            # Final inference
            # Load previous pipeline
-        tokenizer_4 = AutoTokenizer.from_pretrained(args.pretrained_tokenizer_4_name_or_path)
-        tokenizer_4.pad_token = tokenizer_4.eos_token
-        text_encoder_4 = LlamaForCausalLM.from_pretrained(
-            args.pretrained_text_encoder_4_name_or_path,
-            output_hidden_states=True,
-            output_attentions=True,
-            torch_dtype=torch.bfloat16,
-        )
            pipeline = HiDreamImagePipeline.from_pretrained(
                args.pretrained_model_name_or_path,
-            tokenizer_4=tokenizer_4,
-            text_encoder_4=text_encoder_4,
+                tokenizer=None,
+                text_encoder=None,
+                tokenizer_2=None,
+                text_encoder_2=None,
+                tokenizer_3=None,
+                text_encoder_3=None,
+                tokenizer_4=None,
+                text_encoder_4=None,
                revision=args.revision,
                variant=args.variant,
                torch_dtype=weight_dtype,
@@ -1677,28 +1731,25 @@ def main(args):
            pipeline.load_lora_weights(args.output_dir)
            # run inference
-        images = []
-        if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt):
-            prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
-            args.num_validation_images = args.num_validation_images if args.num_validation_images else 1
-            pipeline_args = {"prompt": prompt_to_use, "num_images_per_prompt": args.num_validation_images}
            images = log_validation(
                pipeline=pipeline,
                args=args,
                accelerator=accelerator,
-                pipeline_args=pipeline_args,
+                pipeline_args=validation_embeddings,
                epoch=epoch,
                is_final_validation=True,
                torch_dtype=weight_dtype,
            )
+            del pipeline
+            free_memory()
-        validation_prpmpt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
+        validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt
        save_model_card(
            (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,
            images=images,
            base_model=args.pretrained_model_name_or_path,
            instance_prompt=args.instance_prompt,
-            validation_prompt=validation_prpmpt,
+            validation_prompt=validation_prompt,
            repo_folder=args.output_dir,
        )
@@ -1711,7 +1762,6 @@ def main(args):
            )
        images = None
-        del pipeline
    accelerator.end_training()