Unverified Commit bc108e15 authored by 39th president of the United States, probably's avatar 39th president of the United States, probably Committed by GitHub
Browse files

Fix DREAM training (#8302)

Co-authored-by: Jimmy <39@🇺🇸

.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 86555c9f
......@@ -157,19 +157,19 @@ def compute_dream_and_update_latents(
with torch.no_grad():
pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
noisy_latents, target = (None, None)
_noisy_latents, _target = (None, None)
if noise_scheduler.config.prediction_type == "epsilon":
predicted_noise = pred
delta_noise = (noise - predicted_noise).detach()
delta_noise.mul_(dream_lambda)
noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
target = target.add(delta_noise)
_noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
_target = target.add(delta_noise)
elif noise_scheduler.config.prediction_type == "v_prediction":
raise NotImplementedError("DREAM has not been implemented for v-prediction")
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
return noisy_latents, target
return _noisy_latents, _target
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment