Commit 394243ce authored by Patrick von Platen's avatar Patrick von Platen
Browse files

finish pndm sampler

parent fe985746
...@@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline):
# 1. predict noise model_output # 1. predict noise model_output
model_output = self.unet(image, t)["sample"] model_output = self.unet(image, t)["sample"]
# 2. predict previous mean of image x_t-1 # 2. compute previous image: x_t -> t_t-1
pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"] image = self.scheduler.step(model_output, t, image)["prev_sample"]
# 3. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
......
...@@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline):
image = image.to(torch_device) image = image.to(torch_device)
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)): for t in tqdm(self.scheduler.timesteps):
model_output = self.unet(image, t)["sample"] model_output = self.unet(image, t)["sample"]
image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"] image = self.scheduler.step(model_output, t, image)["prev_sample"]
for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)): # for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
model_output = self.unet(image, t)["sample"] # model_output = self.unet(image, t)["sample"]
#
image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"] # image = self.scheduler.step_prk(model_output, t, image, i=i)["prev_sample"]
#
# for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
# model_output = self.unet(image, t)["sample"]
#
# image = self.scheduler.step_plms(model_output, t, image, i=i)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
......
...@@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline):
for i, t in tqdm(enumerate(self.scheduler.timesteps)): for i, t in tqdm(enumerate(self.scheduler.timesteps)):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
# correction step
for _ in range(self.scheduler.correct_steps): for _ in range(self.scheduler.correct_steps):
model_output = self.model(sample, sigma_t) model_output = self.model(sample, sigma_t)["sample"]
if isinstance(model_output, dict):
model_output = model_output["sample"]
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"] sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
with torch.no_grad(): # prediction step
model_output = model(sample, sigma_t) model_output = model(sample, sigma_t)["sample"]
if isinstance(model_output, dict):
model_output = model_output["sample"]
output = self.scheduler.step_pred(model_output, t, sample) output = self.scheduler.step_pred(model_output, t, sample)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
sample = sample.clamp(0, 1) sample = sample.clamp(0, 1)
......
...@@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
eta, eta: float = 0.0,
use_clipped_model_output=False, use_clipped_model_output: bool = False,
generator=None, generator=None,
): ):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
......
...@@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_end=0.02, beta_end=0.02,
beta_schedule="linear", beta_schedule="linear",
trained_betas=None, trained_betas=None,
timestep_values=None,
variance_type="fixed_small", variance_type="fixed_small",
clip_sample=True, clip_sample=True,
tensor_format="pt", tensor_format="pt",
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
import pdb
from typing import Union from typing import Union
import numpy as np import numpy as np
...@@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# running values # running values
self.cur_model_output = 0 self.cur_model_output = 0
self.counter = 0
self.cur_sample = None self.cur_sample = None
self.ets = [] self.ets = []
# setable values # setable values
self.num_inference_steps = None self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.prk_timesteps = None self.prk_timesteps = None
self.plms_timesteps = None self.plms_timesteps = None
self.timesteps = None
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps): def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.timesteps = list( self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
) )
prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile( prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
) )
self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1])) self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self.timesteps[:-3])) self.plms_timesteps = list(reversed(self._timesteps[:-3]))
self.timesteps = self.prk_timesteps + self.plms_timesteps
self.counter = 0
self.set_format(tensor_format=self.tensor_format) self.set_format(tensor_format=self.tensor_format)
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
if self.counter < len(self.prk_timesteps):
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
else:
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
def step_prk( def step_prk(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
): ):
""" """
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation. solution to the differential equation.
""" """
t = timestep diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
prk_time_steps = self.prk_timesteps prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
timestep = self.prk_timesteps[self.counter // 4 * 4]
t_orig = prk_time_steps[t // 4 * 4]
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
if t % 4 == 0: if self.counter % 4 == 0:
self.cur_model_output += 1 / 6 * model_output self.cur_model_output += 1 / 6 * model_output
self.ets.append(model_output) self.ets.append(model_output)
self.cur_sample = sample self.cur_sample = sample
elif (t - 1) % 4 == 0: elif (self.counter - 1) % 4 == 0:
self.cur_model_output += 1 / 3 * model_output self.cur_model_output += 1 / 3 * model_output
elif (t - 2) % 4 == 0: elif (self.counter - 2) % 4 == 0:
self.cur_model_output += 1 / 3 * model_output self.cur_model_output += 1 / 3 * model_output
elif (t - 3) % 4 == 0: elif (self.counter - 3) % 4 == 0:
model_output = self.cur_model_output + 1 / 6 * model_output model_output = self.cur_model_output + 1 / 6 * model_output
self.cur_model_output = 0 self.cur_model_output = 0
# cur_sample should not be `None` # cur_sample should not be `None`
cur_sample = self.cur_sample if self.cur_sample is not None else sample cur_sample = self.cur_sample if self.cur_sample is not None else sample
return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, model_output)} prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
self.counter += 1
return {"prev_sample": prev_sample}
def step_plms( def step_plms(
self, self,
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
): ):
""" """
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution. times to approximate the solution.
""" """
t = timestep
if len(self.ets) < 3: if len(self.ets) < 3:
raise ValueError( raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run " f"{self.__class__} can only be run AFTER scheduler has been run "
...@@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information." "for more information."
) )
timesteps = self.plms_timesteps prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
t_orig = timesteps[t]
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
self.ets.append(model_output) self.ets.append(model_output)
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, model_output)} prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
return {"prev_sample": prev_sample}
def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output): def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9) # this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation # Note that x_t needs to be added to both sides of the equation
...@@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t # sample -> x_t
# model_output -> e_θ(x_t, t) # model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ) # prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[t_orig + 1] alpha_prod_t = self.alphas_cumprod[timestep + 1]
alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1] alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1]
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev beta_prod_t_prev = 1 - alpha_prod_t_prev
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import pdb
import tempfile import tempfile
import unittest import unittest
...@@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def check_over_configs(self, time_step=0, **config): def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
...@@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config) scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"]) scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals # copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:] scheduler.ets = dummy_past_residuals[:]
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname) scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
new_scheduler.set_timesteps(kwargs["num_inference_steps"]) new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
...@@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def check_over_forward(self, time_step=0, **forward_kwargs): def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample sample = self.dummy_sample
residual = 0.1 * sample residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
...@@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"]) scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals # copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:] scheduler.ets = dummy_past_residuals[:]
...@@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler = scheduler_class.from_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals # copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:] new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_timesteps(kwargs["num_inference_steps"]) new_scheduler.set_timesteps(num_inference_steps)
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
...@@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"] output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"] output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"] output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"] output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
...@@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"] output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"] output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"] output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"] output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape) self.assertEqual(output_0.shape, output_1.shape)
...@@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"] scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
def test_full_loop_no_noise(self): def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
...@@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for i, t in enumerate(scheduler.prk_timesteps): for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t) residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"] sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
for i, t in enumerate(scheduler.plms_timesteps): for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t) residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"] sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
...@@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
model_output = model(sample, sigma_t) model_output = model(sample, sigma_t)
output = scheduler.step_pred(model_output, t, sample, **kwargs) output = scheduler.step_pred(model_output, t, sample, **kwargs)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] sample, _ = output["prev_sample"], output["prev_sample_mean"]
result_sum = torch.sum(torch.abs(sample)) result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment