Unverified Commit 71f34fc5 authored by Linoy Tsaban, committed by GitHub

[Flux LoRA] fix issues in flux lora scripts (#11111)



* remove custom scheduler

* update requirements.txt

* log_validation with mixed precision

* add intermediate embeddings saving when checkpointing is enabled

* remove comment

* fix validation

* add unwrap_model for accelerator, torch.no_grad context for validation, fix accelerator.accumulate call in advanced script

* revert unwrap_model change temp

* add .module to address distributed training bug + replace accelerator.unwrap_model with unwrap_model

* changes to align advanced script with canonical script

* make changes for distributed training + unify unwrap_model calls in advanced script

* add module.dtype fix to dreambooth script

* unify unwrap_model calls in dreambooth script

* fix condition in validation run

* mixed precision

* Update examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* smol style change

* change autocast

* Apply style fixes

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent c51b6bd8
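Several bullets above replace `accelerator.unwrap_model(...)` calls with a local `unwrap_model` helper. A minimal sketch of such a helper, following the usual diffusers training-script pattern (the exact definition lives in the scripts themselves; this is an assumption-laden reconstruction, not the committed code):

```python
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

accelerator = Accelerator()

def unwrap_model(model, keep_fp32_wrapper=True):
    # Strip the accelerate (DDP/FSDP) wrapper first...
    model = accelerator.unwrap_model(model, keep_fp32_wrapper=keep_fp32_wrapper)
    # ...then the torch.compile wrapper, if the model was compiled.
    model = model._orig_mod if is_compiled_module(model) else model
    return model
```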
requirements.txt

-accelerate>=0.16.0
+accelerate>=0.31.0
 torchvision
-transformers>=4.25.1
+transformers>=4.41.2
 ftfy
 tensorboard
 Jinja2
-peft==0.7.0
\ No newline at end of file
+peft>=0.11.1
+sentencepiece
\ No newline at end of file
@@ -895,6 +895,9 @@ def _encode_prompt_with_t5(
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]

-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
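The recurring `hasattr(text_encoder, "module")` check works around the distributed-training bug named in the commit message: under DDP, accelerate wraps each model in `torch.nn.parallel.DistributedDataParallel`, which exposes the underlying model via `.module` but has no `.dtype` attribute of its own. A self-contained sketch of the failure mode and the fix (toy classes standing in for the real encoder and the DDP wrapper so it runs in one process):

```python
import torch

class TinyEncoder(torch.nn.Module):
    # Stand-in for a text encoder; real transformers models expose .dtype similarly.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    @property
    def dtype(self):
        return self.proj.weight.dtype

class DDPLikeWrapper(torch.nn.Module):
    # Mimics DistributedDataParallel: holds the model as .module, defines no .dtype.
    def __init__(self, module):
        super().__init__()
        self.module = module

wrapped = DDPLikeWrapper(TinyEncoder())
# wrapped.dtype would raise AttributeError, just as DDP does in the scripts;
# the fix reaches through .module when it is present.
dtype = wrapped.module.dtype if hasattr(wrapped, "module") else wrapped.dtype
print(dtype)  # torch.float32
```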
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
+
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -958,7 +965,12 @@ def encode_prompt(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
+
     device = device if device is not None else text_encoders[1].device

     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
@@ -1590,7 +1602,7 @@ def main(args):
                 )

                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
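For readers unfamiliar with the guidance handling in this hunk: guidance-distilled FLUX checkpoints take one guidance value per sample, so the scalar CLI argument is broadcast across the batch. A tiny illustration (values are arbitrary stand-ins for `args.guidance_scale` and `model_input.shape[0]`):

```python
import torch

guidance_scale = 3.5  # stand-in for args.guidance_scale
batch_size = 4        # stand-in for model_input.shape[0]

guidance = torch.tensor([guidance_scale])
guidance = guidance.expand(batch_size)
print(guidance)  # tensor([3.5000, 3.5000, 3.5000, 3.5000])
```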
@@ -1716,9 +1728,9 @@ def main(args):
         pipeline = FluxPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             vae=vae,
-            text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
-            text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
-            transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
+            text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
+            text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
+            transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
             revision=args.revision,
             variant=args.variant,
             torch_dtype=weight_dtype,
...
@@ -177,16 +177,25 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)

     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
-    autocast_ctx = nullcontext()
+    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+        )
+
+    images = []
+    for _ in range(args.num_validation_images):
+        with autocast_ctx:
+            image = pipeline(
+                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
+            ).images[0]
+            images.append(image)

@@ -203,8 +212,7 @@ def log_validation(
     )

     del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    free_memory()

     return images
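The `free_memory()` that replaces the bare `torch.cuda.empty_cache()` call comes from `diffusers.training_utils`. Roughly, and this is a sketch rather than the exact upstream implementation, it collects Python garbage first and then returns cached blocks to whichever accelerator backend is present:

```python
import gc
import torch

def free_memory():
    # Sketch of diffusers.training_utils.free_memory: drop dangling Python
    # references first, then release the device allocator's cache.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
```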
@@ -932,6 +940,9 @@ def _encode_prompt_with_t5(
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]

-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
@@ -973,9 +984,13 @@ def _encode_prompt_with_clip(
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
+
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -994,6 +1009,10 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype

     pooled_prompt_embeds = _encode_prompt_with_clip(
@@ -1619,7 +1638,7 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)

         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
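One of the commit bullets mentions fixing the `accelerator.accumulate` call in the advanced script; the `models_to_accumulate` list visible in this hunk feeds that call. A runnable toy sketch of the pattern, with small linear layers standing in for the transformer and text encoder (names and shapes are illustrative, not the script's):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model_a = torch.nn.Linear(2, 2)  # stand-in for transformer
model_b = torch.nn.Linear(2, 2)  # stand-in for text_encoder_one
model_a, model_b = accelerator.prepare(model_a, model_b)

# Every trained model must be passed so gradients sync on the right steps.
models_to_accumulate = [model_a, model_b]
x = torch.randn(8, 2, device=accelerator.device)

with accelerator.accumulate(*models_to_accumulate):
    loss = (model_a(x) + model_b(x)).pow(2).mean()
    accelerator.backward(loss)
```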
@@ -1710,7 +1729,7 @@ def main(args):
                 )

                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1828,9 +1847,9 @@ def main(args):
         pipeline = FluxPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             vae=vae,
-            text_encoder=accelerator.unwrap_model(text_encoder_one),
-            text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-            transformer=accelerator.unwrap_model(transformer),
+            text_encoder=unwrap_model(text_encoder_one),
+            text_encoder_2=unwrap_model(text_encoder_two),
+            transformer=unwrap_model(transformer),
             revision=args.revision,
             variant=args.variant,
             torch_dtype=weight_dtype,
...
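Not part of the diff, but for context: once either script finishes, the saved LoRA can be loaded onto the base model for inference. A hypothetical usage sketch (the weights path and prompt are placeholders; point the path at whatever `--output_dir` was set to):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("path/to/output_dir")  # placeholder path
pipe.to("cuda")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=28).images[0]
image.save("flux_lora_sample.png")
```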