"examples/vscode:/vscode.git/clone" did not exist on "c5e683c439c9b7428b772ebc93edb42b4b422ced"
Unverified Commit 1bcd19e4 authored by hlky's avatar hlky Committed by GitHub
Browse files

Add pred_original_sample to `if not return_dict` path (#9649)


Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 22ed39f5
......@@ -463,7 +463,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
......@@ -394,7 +394,10 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = a_t * sample + b_t * pred_original_sample
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
......@@ -480,7 +480,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
......@@ -492,7 +492,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample,)
return (
pred_prev_sample,
pred_original_sample,
)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
......
......@@ -500,7 +500,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample,)
return (
pred_prev_sample,
pred_original_sample,
)
return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
......
......@@ -360,7 +360,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
......@@ -435,7 +435,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return EulerAncestralDiscreteSchedulerOutput(
prev_sample=prev_sample, pred_original_sample=pred_original_sample
......
......@@ -677,7 +677,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
......@@ -507,7 +507,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index += 1
if not return_dict:
return (prev_sample,)
return (
prev_sample,
pred_original_sample,
)
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
......
......@@ -320,7 +320,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample,)
return (
pred_prev_sample,
pred_original_sample,
)
return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment