Unverified Commit c5f04d4e authored by jiqing-feng's avatar jiqing-feng Committed by GitHub
Browse files

apply amp bf16 on textual inversion (#1465)

* add conf.yaml

* enable bf16

enable amp bf16 for unet forward

fix style

fix readme

remove useless file

* change amp to full bf16

* align

* make style

* fix format
parent 61dec533
...@@ -532,9 +532,15 @@ def main(): ...@@ -532,9 +532,15 @@ def main():
) )
accelerator.register_for_checkpointing(lr_scheduler) accelerator.register_for_checkpointing(lr_scheduler)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move vae and unet to device # Move vae and unet to device
vae.to(accelerator.device) unet.to(accelerator.device, dtype=weight_dtype)
unet.to(accelerator.device) vae.to(accelerator.device, dtype=weight_dtype)
# Keep vae and unet in eval model as we don't train these # Keep vae and unet in eval model as we don't train these
vae.eval() vae.eval()
...@@ -600,11 +606,11 @@ def main(): ...@@ -600,11 +606,11 @@ def main():
with accelerator.accumulate(text_encoder): with accelerator.accumulate(text_encoder):
# Convert images to latent space # Convert images to latent space
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * 0.18215 latents = latents * 0.18215
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn(latents.shape).to(latents.device) noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
bsz = latents.shape[0] bsz = latents.shape[0]
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint( timesteps = torch.randint(
...@@ -616,7 +622,7 @@ def main(): ...@@ -616,7 +622,7 @@ def main():
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning # Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0] encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
# Predict the noise residual # Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
...@@ -629,7 +635,7 @@ def main(): ...@@ -629,7 +635,7 @@ def main():
else: else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean() loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss) accelerator.backward(loss)
optimizer.step() optimizer.step()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment