Unverified commit 71f34fc5, authored by Linoy Tsaban, committed by GitHub

[Flux LoRA] fix issues in flux lora scripts (#11111)



* remove custom scheduler

* update requirements.txt

* log_validation with mixed precision

* add intermediate embeddings saving when checkpointing is enabled

* remove comment

* fix validation

* add unwrap_model for accelerator, torch.no_grad context for validation, fix accelerator.accumulate call in advanced script

* revert unwrap_model change temp

* add .module to address distributed training bug + replace accelerator.unwrap_model with unwrap model

* changes to align advanced script with canonical script

* make changes for distributed training + unify unwrap_model calls in advanced script

* add module.dtype fix to dreambooth script

* unify unwrap_model calls in dreambooth script

* fix condition in validation run

* mixed precision

* Update examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* smol style change

* change autocast

* Apply style fixes

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent c51b6bd8
-accelerate>=0.16.0
+accelerate>=0.31.0
 torchvision
-transformers>=4.25.1
+transformers>=4.41.2
 ftfy
 tensorboard
 Jinja2
-peft==0.7.0
+peft>=0.11.1
+sentencepiece
\ No newline at end of file
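Note on the requirements bump: sentencepiece is new because Flux's second text encoder is T5, whose tokenizer is SentencePiece-based. A quick, illustrative environment check (the helper below is a sketch, not part of the scripts):

# Illustrative environment check for the updated requirements (not part of the scripts).
import accelerate
import peft
import sentencepiece  # backs T5's tokenizer, used for Flux's text_encoder_2

def at_least(version, minimum):
    # compare only the numeric major/minor components
    return tuple(int(p) for p in version.split(".")[:2]) >= minimum

assert at_least(peft.__version__, (0, 11)), "peft>=0.11.1 required"
assert at_least(accelerate.__version__, (0, 31)), "accelerate>=0.31.0 required"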
@@ -24,7 +24,7 @@
 import re
 import shutil
 from contextlib import nullcontext
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional
 import numpy as np
 import torch
@@ -228,10 +228,20 @@ def log_validation(
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    autocast_ctx = nullcontext()
+    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
 
-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+        )
+    images = []
+    for _ in range(args.num_validation_images):
+        with autocast_ctx:
+            image = pipeline(
+                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
+            ).images[0]
+        images.append(image)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
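The reworked validation above follows one pattern in all three scripts: encode the prompt once outside of autocast (T5 misbehaves under autocast), then run generation with autocast enabled only during intermediate validation. A condensed sketch of that flow (the function name and argument handling are illustrative, not the scripts' exact code):

# Condensed sketch of the validation flow above; names are illustrative.
from contextlib import nullcontext

import torch

def run_validation(pipeline, prompt, num_images, device_type, seed=None, is_final_validation=False):
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    # autocast only for intermediate validation; the final run uses the pipeline's own dtype
    autocast_ctx = torch.autocast(device_type) if not is_final_validation else nullcontext()

    # encode once, outside autocast, since T5 does not support autocast
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=prompt)

    images = []
    for _ in range(num_images):
        with autocast_ctx:
            image = pipeline(
                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
            ).images[0]
        images.append(image)
    return images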
@@ -657,6 +667,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
+
     parser.add_argument(
         "--lora_layers",
         type=str,
@@ -666,6 +677,7 @@ def parse_args(input_args=None):
             'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
         ),
     )
+
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -738,6 +750,15 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
+    parser.add_argument(
+        "--upcast_before_saving",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+            "Defaults to precision dtype used for training to save memory"
+        ),
+    )
     parser.add_argument(
         "--prior_generation_precision",
         type=str,
@@ -1147,7 +1168,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False):
     return text_input_ids
 
-def _get_t5_prompt_embeds(
+def _encode_prompt_with_t5(
     text_encoder,
     tokenizer,
     max_sequence_length=512,
@@ -1176,7 +1197,10 @@ def _get_t5_prompt_embeds(
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
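Background for the repeated hasattr(text_encoder, "module") checks: under multi-GPU training, accelerate wraps each model in DistributedDataParallel, which stores the real model in .module and does not forward Python properties such as transformers' .dtype (or .config). A self-contained sketch of the pattern (TinyEncoder and model_dtype are illustrative names, not script code):

# Self-contained sketch; TinyEncoder and model_dtype are illustrative.
import torch
from torch import nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def dtype(self):  # transformers' PreTrainedModel exposes .dtype as a property like this
        return self.proj.weight.dtype

def model_dtype(model):
    # DistributedDataParallel keeps the wrapped model in .module and does not
    # forward properties like .dtype, hence the hasattr check in the diff above.
    return model.module.dtype if hasattr(model, "module") else model.dtype

encoder = TinyEncoder().to(torch.bfloat16)
print(model_dtype(encoder))  # torch.bfloat16, wrapped in DDP or not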
@@ -1188,7 +1212,7 @@ def _get_t5_prompt_embeds(
     return prompt_embeds
 
-def _get_clip_prompt_embeds(
+def _encode_prompt_with_clip(
     text_encoder,
     tokenizer,
     prompt: str,
@@ -1217,9 +1241,13 @@ def _get_clip_prompt_embeds(
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
+
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1238,136 +1266,35 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
-    dtype = text_encoders[0].dtype
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
 
-    pooled_prompt_embeds = _get_clip_prompt_embeds(
+    pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
         tokenizer=tokenizers[0],
         prompt=prompt,
         device=device if device is not None else text_encoders[0].device,
         num_images_per_prompt=num_images_per_prompt,
-        text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
     )
 
-    prompt_embeds = _get_t5_prompt_embeds(
+    prompt_embeds = _encode_prompt_with_t5(
         text_encoder=text_encoders[1],
         tokenizer=tokenizers[1],
         max_sequence_length=max_sequence_length,
         prompt=prompt,
         num_images_per_prompt=num_images_per_prompt,
         device=device if device is not None else text_encoders[1].device,
-        text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
+        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
     )
 
-    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
     return prompt_embeds, pooled_prompt_embeds, text_ids
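Note on the text_ids change: recent diffusers expects Flux's text position ids as a single unbatched (seq_len, 3) tensor shared across the batch, so the batch dimension and the num_images_per_prompt repeat are dropped. Concretely:

# text_ids after the change: one unbatched (seq_len, 3) tensor of zeros.
import torch

max_sequence_length = 512  # the T5 sequence length these scripts use
text_ids = torch.zeros(max_sequence_length, 3)
print(text_ids.shape)  # torch.Size([512, 3])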
-# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
-# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
-class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        with torch.no_grad():
-            # create weights for timesteps
-            num_timesteps = 1000
-
-            # generate the multiplier based on cosmap loss weighing
-            # this is only used on linear timesteps for now
-
-            # cosine map weighing is higher in the middle and lower at the ends
-            # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
-            # cosmap_weighing = 2 / (math.pi * bot)
-
-            # sigma sqrt weighing is significantly higher at the end and lower at the beginning
-            sigma_sqrt_weighing = (self.sigmas**-2.0).float()
-            # clip at 1e4 (1e6 is too high)
-            sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
-            # bring to a mean of 1
-            sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
-
-            # Create linear timesteps from 1000 to 0
-            timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
-
-            self.linear_timesteps = timesteps
-            # self.linear_timesteps_weights = cosmap_weighing
-            self.linear_timesteps_weights = sigma_sqrt_weighing
-
-            # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
-            pass
-
-    def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
-        # Get the indices of the timesteps
-        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
-
-        # Get the weights for the timesteps
-        weights = self.linear_timesteps_weights[step_indices].flatten()
-
-        return weights
-
-    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
-        sigmas = self.sigmas.to(device=device, dtype=dtype)
-        schedule_timesteps = self.timesteps.to(device)
-        timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < n_dim:
-            sigma = sigma.unsqueeze(-1)
-
-        return sigma
-
-    def add_noise(
-        self,
-        original_samples: torch.Tensor,
-        noise: torch.Tensor,
-        timesteps: torch.Tensor,
-    ) -> torch.Tensor:
-        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
-        ## Add noise according to flow matching.
-        ## zt = (1 - texp) * x + texp * z1
-
-        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
-        # timestep needs to be in [0, 1], we store them in [0, 1000]
-        # noisy_sample = (1 - timestep) * latent + timestep * noise
-        t_01 = (timesteps / 1000).to(original_samples.device)
-        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
-
-        # n_dim = original_samples.ndim
-        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
-
-        return noisy_model_input
-
-    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
-        return sample
-
-    def set_train_timesteps(self, num_timesteps, device, linear=False):
-        if linear:
-            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
-            self.timesteps = timesteps
-            return timesteps
-        else:
-            # distribute them closer to center. Inference distributes them as a bias toward first
-            # Generate values from 0 to 1
-            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
-
-            # Scale and reverse the values to go from 1000 to 0
-            timesteps = (1 - t) * 1000
-
-            # Sort the timesteps in descending order
-            timesteps, _ = torch.sort(timesteps, descending=True)
-
-            self.timesteps = timesteps.to(device=device)
-
-            return timesteps
 
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -1499,7 +1426,7 @@ def main(args):
     )
 
     # Load scheduler and models
-    noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
+    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
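With the custom scheduler removed, timestep sampling and loss weighting come from the stock diffusers.training_utils helpers the canonical script already uses; weighting_scheme="sigma_sqrt" corresponds to the sigmas**-2.0 weighting the deleted class hand-rolled. An illustrative sketch (the argument values are common defaults, not taken from this diff, and the sigma mapping is a placeholder):

# Illustrative sketch; values are common defaults, not taken from this diff.
import torch
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
)

batch_size = 4
# sample where in [0, 1] to place the training timesteps
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=batch_size, logit_mean=0.0, logit_std=1.0
)
# the scripts map u to scheduler sigmas; a placeholder stands in here
sigmas = u.reshape(batch_size, 1, 1, 1)
# "sigma_sqrt" reproduces the deleted class's sigmas**-2.0 weighting
weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)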
@@ -1619,7 +1546,6 @@ def main(args):
         target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
-
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
@@ -1727,7 +1653,6 @@ def main(args):
         cast_training_params(models, dtype=torch.float32)
 
     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-
     if args.train_text_encoder:
         text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
     # if we use textual inversion, we freeze all parameters except for the token embeddings
@@ -1737,7 +1662,8 @@ def main(args):
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param.data = param.to(dtype=torch.float32)
+                if args.mixed_precision == "fp16":
+                    param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
@@ -1747,7 +1673,8 @@ def main(args):
         for name, param in text_encoder_two.named_parameters():
             if "shared" in name:
                 # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-                param.data = param.to(dtype=torch.float32)
+                if args.mixed_precision == "fp16":
+                    param.data = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
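The two hunks above gate the float32 upcast on fp16 only: fp16 training needs fp32 master copies of the trainable embeddings for numerically stable updates, while bf16 can train them in their current dtype. The rule in isolation (illustrative helper, not script code):

# Illustrative helper restating the rule from the two hunks above.
import torch

def prepare_trainable_embedding(param, mixed_precision):
    if mixed_precision == "fp16":
        # fp16 needs an fp32 master copy for numerically stable updates
        param.data = param.to(dtype=torch.float32)
    # bf16 (and fp32) can train the embedding in its current dtype
    param.requires_grad = True
    return param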
@@ -1828,6 +1755,7 @@ def main(args):
         optimizer_class = bnb.optim.AdamW8bit
     else:
         optimizer_class = torch.optim.AdamW
+
     optimizer = optimizer_class(
         params_to_optimize,
         betas=(args.adam_beta1, args.adam_beta2),
@@ -2021,6 +1949,7 @@ def main(args):
             lr_scheduler,
         )
     else:
+        print("I SHOULD BE HERE")
         transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
         )
@@ -2125,7 +2054,7 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
         elif args.train_text_encoder_ti:  # textual inversion / pivotal tuning
             text_encoder_one.train()
             if args.enable_t5_ti:
@@ -2137,6 +2066,11 @@ def main(args):
                 pivoted_tr = True
 
         for step, batch in enumerate(train_dataloader):
+            models_to_accumulate = [transformer]
+            if not freeze_text_encoder:
+                models_to_accumulate.extend([text_encoder_one])
+                if args.enable_t5_ti:
+                    models_to_accumulate.extend([text_encoder_two])
             if pivoted_te:
                 # stopping optimization of text_encoder params
                 optimizer.param_groups[te_idx]["lr"] = 0.0
@@ -2145,7 +2079,7 @@ def main(args):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(transformer):
+            with accelerator.accumulate(models_to_accumulate):
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
@@ -2189,7 +2123,7 @@ def main(args):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
 
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
) )
# handle guidance # handle guidance
if transformer.config.guidance_embeds: if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0]) guidance = guidance.expand(model_input.shape[0])
else: else:
@@ -2288,16 +2222,26 @@ def main(args):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     if not freeze_text_encoder:
-                        if args.train_text_encoder:
+                        if args.train_text_encoder:  # text encoder tuning
                             params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
                         elif pure_textual_inversion:
-                            params_to_clip = itertools.chain(
-                                text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    text_encoder_one.parameters(), text_encoder_two.parameters()
+                                )
+                            else:
+                                params_to_clip = itertools.chain(text_encoder_one.parameters())
                         else:
-                            params_to_clip = itertools.chain(
-                                transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
-                            )
+                            if args.enable_t5_ti:
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(),
+                                    text_encoder_one.parameters(),
+                                    text_encoder_two.parameters(),
+                                )
+                            else:
+                                params_to_clip = itertools.chain(
+                                    transformer.parameters(), text_encoder_one.parameters()
+                                )
                     else:
                         params_to_clip = itertools.chain(transformer.parameters())
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -2339,6 +2283,10 @@ def main(args):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
+                        if args.train_text_encoder_ti:
+                            embedding_handler.save_embeddings(
+                                f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors"
+                            )
                         logger.info(f"Saved state to {save_path}")
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -2351,14 +2299,16 @@ def main(args):
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                 # create pipeline
-                if freeze_text_encoder:
+                if freeze_text_encoder:  # no text encoder one, two optimizations
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    transformer=accelerator.unwrap_model(transformer),
+                    text_encoder=unwrap_model(text_encoder_one),
+                    text_encoder_2=unwrap_model(text_encoder_two),
+                    transformer=unwrap_model(transformer),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -2372,21 +2322,21 @@ def main(args):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-                images = None
-                del pipeline
                 if freeze_text_encoder:
                     del text_encoder_one, text_encoder_two
                     free_memory()
-                elif args.train_text_encoder:
-                    del text_encoder_two
-                    free_memory()
+
+                images = None
+                del pipeline
 
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
-        transformer = transformer.to(weight_dtype)
+        if args.upcast_before_saving:
+            transformer.to(torch.float32)
+        else:
+            transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
         if args.train_text_encoder:
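Note on the new save path: upcasting to float32 produces a LoRA checkpoint that loads cleanly whatever dtype inference later uses, at roughly twice the file size; otherwise the weights stay in the training dtype. A minimal sketch of the choice (export_lora is an illustrative name, not a script function):

# Illustrative sketch of the save-time dtype choice; export_lora is not a script function.
import torch
from peft.utils import get_peft_model_state_dict

def export_lora(transformer, weight_dtype, upcast_before_saving=False):
    if upcast_before_saving:
        transformer.to(torch.float32)  # loads cleanly under any inference dtype
    else:
        transformer = transformer.to(weight_dtype)  # smaller file, keeps training dtype
    return get_peft_model_state_dict(transformer)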
@@ -2428,8 +2378,8 @@ def main(args):
                 accelerator=accelerator,
                 pipeline_args=pipeline_args,
                 epoch=epoch,
-                torch_dtype=weight_dtype,
                 is_final_validation=True,
+                torch_dtype=weight_dtype,
             )
 
     save_model_card(
save_model_card( save_model_card(
...@@ -2452,6 +2402,7 @@ def main(args): ...@@ -2452,6 +2402,7 @@ def main(args):
commit_message="End of training", commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"], ignore_patterns=["step_*", "epoch_*"],
) )
images = None images = None
del pipeline del pipeline
......
@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
+
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -958,7 +965,12 @@ def encode_prompt(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
+
     device = device if device is not None else text_encoders[1].device
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
@@ -1590,7 +1602,7 @@ def main(args):
                 )
 
                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1716,9 +1728,9 @@ def main(args):
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
-                    transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
+                    text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
+                    text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
+                    transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
...
@@ -177,16 +177,25 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
-    autocast_ctx = nullcontext()
+    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
 
-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+    # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
+        )
+    images = []
+    for _ in range(args.num_validation_images):
+        with autocast_ctx:
+            image = pipeline(
+                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
+            ).images[0]
+        images.append(image)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
@@ -203,8 +212,7 @@ def log_validation(
     )
 
     del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    free_memory()
 
     return images
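free_memory() comes from diffusers.training_utils and generalizes the old CUDA-only cache clearing across backends; roughly (a simplified sketch, not the library's exact implementation):

# Simplified sketch of diffusers.training_utils.free_memory, not its exact code.
import gc

import torch

def free_memory_sketch():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()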
@@ -932,7 +940,10 @@ def _encode_prompt_with_t5(
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
@@ -973,9 +984,13 @@ def _encode_prompt_with_clip(
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
+
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -994,7 +1009,11 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
@@ -1619,7 +1638,7 @@ def main(args):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
@@ -1710,7 +1729,7 @@ def main(args):
                 )
 
                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1828,9 +1847,9 @@ def main(args):
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    transformer=accelerator.unwrap_model(transformer),
+                    text_encoder=unwrap_model(text_encoder_one),
+                    text_encoder_2=unwrap_model(text_encoder_two),
+                    transformer=unwrap_model(transformer),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
...