Unverified Commit e1b5b8ba authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Make sure fp16-fix is used as default (#4510)

* Make sue fp16-fix is used as default

* fix vae

* finish

* fix
parent dff5ff35
......@@ -73,17 +73,19 @@ Now, we can launch training using:
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="lora-trained-xl"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
accelerate launch train_dreambooth_lora_sdxl.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--pretrained_vae_model_name_or_path=$VAE_PATH \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
......
......@@ -732,12 +732,11 @@ def main(args):
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
unet.to(accelerator.device, dtype=weight_dtype)
if args.pretrained_vae_model_name_or_path is None:
vae.to(accelerator.device, dtype=torch.float32)
else:
vae.to(accelerator.device, dtype=weight_dtype)
# The VAE is always in float32 to avoid NaN losses.
vae.to(accelerator.device, dtype=torch.float32)
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
......@@ -1070,10 +1069,7 @@ def main(args):
continue
with accelerator.accumulate(unet):
if args.pretrained_vae_model_name_or_path is None:
pixel_values = batch["pixel_values"]
else:
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment