Unverified Commit 63a0c9e5 authored by G.O.D's avatar G.O.D Committed by GitHub
Browse files

[bugfix] reduce float value error when adding noise (#9004)



* Update train_controlnet.py

reduce float value error for bfloat16

* Update train_controlnet_sdxl.py

* style

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
parent e2d037bb
...@@ -1048,7 +1048,9 @@ def main(args): ...@@ -1048,7 +1048,9 @@ def main(args):
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
dtype=weight_dtype
)
# Get the text embedding for conditioning # Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
......
...@@ -1210,7 +1210,9 @@ def main(args): ...@@ -1210,7 +1210,9 @@ def main(args):
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
dtype=weight_dtype
)
# ControlNet conditioning. # ControlNet conditioning.
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment