Unverified Commit 1bcd19e4 authored by hlky's avatar hlky Committed by GitHub
Browse files

Add pred_original_sample to `if not return_dict` path (#9649)


Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 22ed39f5
...@@ -463,7 +463,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -463,7 +463,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance prev_sample = prev_sample + variance
if not return_dict: if not return_dict:
return (prev_sample,) return (
prev_sample,
pred_original_sample,
)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
...@@ -394,7 +394,10 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -394,7 +394,10 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = a_t * sample + b_t * pred_original_sample prev_sample = a_t * sample + b_t * pred_original_sample
if not return_dict: if not return_dict:
return (prev_sample,) return (
prev_sample,
pred_original_sample,
)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
...@@ -480,7 +480,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin): ...@@ -480,7 +480,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance prev_sample = prev_sample + variance
if not return_dict: if not return_dict:
return (prev_sample,) return (
prev_sample,
pred_original_sample,
)
return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
...@@ -492,7 +492,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -492,7 +492,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance pred_prev_sample = pred_prev_sample + variance
if not return_dict: if not return_dict:
return (pred_prev_sample,) return (
pred_prev_sample,
pred_original_sample,
)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
......
...@@ -500,7 +500,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): ...@@ -500,7 +500,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance pred_prev_sample = pred_prev_sample + variance
if not return_dict: if not return_dict:
return (pred_prev_sample,) return (
pred_prev_sample,
pred_original_sample,
)
return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
......
...@@ -360,7 +360,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): ...@@ -360,7 +360,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1 self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (
prev_sample,
pred_original_sample,
)
return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
...@@ -435,7 +435,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -435,7 +435,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1 self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (
prev_sample,
pred_original_sample,
)
return EulerAncestralDiscreteSchedulerOutput( return EulerAncestralDiscreteSchedulerOutput(
prev_sample=prev_sample, pred_original_sample=pred_original_sample prev_sample=prev_sample, pred_original_sample=pred_original_sample
......
...@@ -677,7 +677,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -677,7 +677,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1 self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (
prev_sample,
pred_original_sample,
)
return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
...@@ -507,7 +507,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -507,7 +507,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1 self._step_index += 1
if not return_dict: if not return_dict:
return (prev_sample,) return (
prev_sample,
pred_original_sample,
)
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
...@@ -320,7 +320,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin): ...@@ -320,7 +320,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance pred_prev_sample = pred_prev_sample + variance
if not return_dict: if not return_dict:
return (pred_prev_sample,) return (
pred_prev_sample,
pred_original_sample,
)
return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment