"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "249a9e48e8f8aac4356d5a285c8ba0c600a80f64"
Unverified Commit bc108e15 authored by 39th president of the United States, probably's avatar 39th president of the United States, probably Committed by GitHub
Browse files

Fix DREAM training (#8302)

Co-authored-by: Jimmy <39@🇺🇸

.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 86555c9f
...@@ -157,19 +157,19 @@ def compute_dream_and_update_latents( ...@@ -157,19 +157,19 @@ def compute_dream_and_update_latents(
with torch.no_grad(): with torch.no_grad():
pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
noisy_latents, target = (None, None) _noisy_latents, _target = (None, None)
if noise_scheduler.config.prediction_type == "epsilon": if noise_scheduler.config.prediction_type == "epsilon":
predicted_noise = pred predicted_noise = pred
delta_noise = (noise - predicted_noise).detach() delta_noise = (noise - predicted_noise).detach()
delta_noise.mul_(dream_lambda) delta_noise.mul_(dream_lambda)
noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise) _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
target = target.add(delta_noise) _target = target.add(delta_noise)
elif noise_scheduler.config.prediction_type == "v_prediction": elif noise_scheduler.config.prediction_type == "v_prediction":
raise NotImplementedError("DREAM has not been implemented for v-prediction") raise NotImplementedError("DREAM has not been implemented for v-prediction")
else: else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
return noisy_latents, target return _noisy_latents, _target
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment