Unverified Commit 81331f3b authored by Bagheera's avatar Bagheera Committed by GitHub
Browse files

Add x-prediction / prediction_type=sample support for SDXL fine-tuning (#5095)


Co-authored-by: default avatarbghira <bghira@users.github.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 29970757
...@@ -998,6 +998,11 @@ def main(args): ...@@ -998,6 +998,11 @@ def main(args):
target = noise target = noise
elif noise_scheduler.config.prediction_type == "v_prediction": elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps) target = noise_scheduler.get_velocity(model_input, noise, timesteps)
elif noise_scheduler.config.prediction_type == "sample":
# We set the target to latents here, but the model_pred will return the noise sample prediction.
target = model_input
# We will have to subtract the noise residual from the prediction to get the target sample.
model_pred = model_pred - noise
else: else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment