Unverified Commit 182b164f authored by Nathan Lambert, committed by GitHub

Fix VE SDE tests, clean API (#95)



* clean ddpm api to match ddim

* correct ve sde class

* update pipeline API for ve sde

* make style

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 8b42c7ce
@@ -40,8 +40,10 @@ class DDPMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)

-        num_prediction_steps = len(self.scheduler)
-        for t in tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+        # set step values
+        self.scheduler.set_timesteps(1000)
+
+        for t in tqdm(self.scheduler.timesteps):
             # 1. predict noise model_output
             with torch.no_grad():
                 model_output = self.unet(image, t)
...
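Note: with this hunk the pipeline asks the scheduler for its step values instead of iterating `reversed(range(len(self.scheduler)))`. A minimal runnable sketch of the new pattern (the `ToyScheduler` and `unet` stand-ins are illustrative, not diffusers code; only `set_timesteps` and the `timesteps` attribute mirror the new API):

import torch
from tqdm import tqdm

class ToyScheduler:
    # stand-in exposing just the new API surface
    def set_timesteps(self, num_inference_steps):
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))

def unet(image, t):
    # placeholder noise model
    return torch.zeros_like(image)

scheduler = ToyScheduler()
image = torch.randn(1, 3, 8, 8)
scheduler.set_timesteps(10)  # the scheduler now owns the step values
for t in tqdm(scheduler.timesteps):
    with torch.no_grad():
        model_output = unet(image, t)  # the denoising update itself is unchanged by this hunk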
@@ -18,8 +18,8 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         model = self.model.to(device)

-        x = torch.randn(*shape) * self.scheduler.config.sigma_max
-        x = x.to(device)
+        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = sample.to(device)

         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
@@ -29,19 +29,20 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             for _ in range(self.scheduler.correct_steps):
                 with torch.no_grad():
-                    result = self.model(x, sigma_t)
-                if isinstance(result, dict):
-                    result = result["sample"]
-                x = self.scheduler.step_correct(result, x)
+                    model_output = self.model(sample, sigma_t)
+                if isinstance(model_output, dict):
+                    model_output = model_output["sample"]
+                sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]

             with torch.no_grad():
-                result = model(x, sigma_t)
-            if isinstance(result, dict):
-                result = result["sample"]
-            x, x_mean = self.scheduler.step_pred(result, x, t)
+                model_output = model(sample, sigma_t)
+            if isinstance(model_output, dict):
+                model_output = model_output["sample"]
+            output = self.scheduler.step_pred(model_output, t, sample)
+            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

-        return x_mean
+        return sample_mean
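Call sites migrating to this change move from tuple unpacking to dict indexing, and the argument order becomes (output, timestep, sample). A sketch against a stub that mimics only the new return shape (the stub is not the scheduler implementation):

import torch

def step_pred_stub(model_output, timestep, sample):
    # mimics only the return convention: a dict, not a tuple
    prev = sample - model_output
    return {"prev_sample": prev, "prev_sample_mean": prev}

sample = torch.randn(2, 3, 8, 8)
model_output = 0.1 * sample
output = step_pred_stub(model_output, 0.5, sample)  # (output, timestep, sample) order
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]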
@@ -86,8 +86,20 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
         self.one = np.array(1.0)

+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+        self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)

+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.arange(
+            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+        )[::-1].copy()
+        self.set_format(tensor_format=self.tensor_format)
+
     def get_variance(self, t, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
...
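The stride logic of the new `DDPMScheduler.set_timesteps` can be checked standalone (the values below are illustrative, not from this PR's tests):

import numpy as np

num_train_timesteps = 1000
num_inference_steps = 50
timesteps = np.arange(
    0, num_train_timesteps, num_train_timesteps // num_inference_steps
)[::-1].copy()
print(timesteps[:3], timesteps[-3:])  # [980 960 940] ... [40 20 0], descending like the old reversed(range(...)) loop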
@@ -15,6 +15,8 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit

+import pdb
+from typing import Union
 import numpy as np
 import torch
@@ -27,8 +29,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
     """
     The variance exploding stochastic differential equation (SDE) scheduler.

-    :param snr: coefficient weighting the step from the score sample (from the network) to the random noise. :param
-    sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+    :param snr: coefficient weighting the step from the model_output sample (from the network) to the random noise.
+    :param sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
     distribution of the data.
     :param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progessively from 1 to
     epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
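For orientation, a construction sketch using the values that appear in this PR's test config; it assumes `ScoreSdeVeScheduler` is importable from `diffusers` as in the tests and that unspecified parameters keep their defaults:

from diffusers import ScoreSdeVeScheduler

scheduler = ScoreSdeVeScheduler(
    sigma_min=0.01,     # minimum of the sigma sequence, should mirror the data distribution
    sigma_max=1348,     # maximum noise scale
    sampling_eps=1e-5,  # end value of sampling time
    tensor_format="pt",
)
scheduler.set_timesteps(10)  # continuous timesteps from 1 down to sampling_eps
scheduler.set_sigmas(10)     # the matching sigma schedule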
@@ -54,12 +56,16 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             sampling_eps=sampling_eps,
             correct_steps=correct_steps,
         )

-        self.sigmas = None
-        self.discrete_sigmas = None
-
+        # self.sigmas = None
+        # self.discrete_sigmas = None
+        #
+        # # setable values
+        # self.num_inference_steps = None
         self.timesteps = None
+        self.set_sigmas(self.num_train_timesteps)

-        # TODO - update step to be torch-independant
+        self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)

     def set_timesteps(self, num_inference_steps):
@@ -104,52 +110,80 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

-    def step_pred(self, score, x, t):
+    def set_seed(self, seed):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            np.random.seed(seed)
+        elif tensor_format == "pt":
+            torch.manual_seed(seed)
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def step_pred(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        seed=None,
+    ):
         """
         Predict the sample at the previous timestep by reversing the SDE.
         """
-        # TODO(Patrick) better comments + non-PyTorch
-        t = self.repeat_scalar(t, x.shape[0]).to(x.device)
-        timesteps = self.long((t * (len(self.timesteps) - 1))).to(x.device)
+        if seed is not None:
+            self.set_seed(seed)
+        # TODO(Patrick) non-PyTorch
+        timestep = timestep * torch.ones(
+            sample.shape[0], device=sample.device
+        )  # torch.repeat_interleave(timestep, sample.shape[0])
+        timesteps = (timestep * (len(self.timesteps) - 1)).long()

-        sigma = self.discrete_sigmas[timesteps].to(x.device)
-        adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
-        drift = self.zeros_like(x)
+        sigma = self.discrete_sigmas[timesteps].to(sample.device)
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep)
+        drift = self.zeros_like(sample)
         diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

-        # equation 6 in the paper: the score modeled by the network is grad_x log pt(x)
+        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
-        drift = drift - diffusion[:, None, None, None] ** 2 * score
+        drift = drift - diffusion[:, None, None, None] ** 2 * model_output

         # equation 6: sample noise for the diffusion term of
-        noise = self.randn_like(x)
-        x_mean = x - drift  # subtract because `dt` is a small negative timestep
+        noise = self.randn_like(sample)
+        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
-        x = x_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
-        return x, x_mean
+        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g

-    def step_correct(self, score, x):
+        return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean}
+
+    def step_correct(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        sample: Union[torch.FloatTensor, np.ndarray],
+        seed=None,
+    ):
         """
-        Correct the predicted sample based on the output score of the network. This is often run repeatedly after
-        making the prediction for the previous timestep.
+        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+        after making the prediction for the previous timestep.
         """
-        # TODO(Patrick) non-PyTorch
+        if seed is not None:
+            self.set_seed(seed)
         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
         # sample noise for correction
-        noise = self.randn_like(x)
+        noise = self.randn_like(sample)

-        # compute step size from the score, the noise, and the snr
-        grad_norm = self.norm(score)
+        # compute step size from the model_output, the noise, and the snr
+        grad_norm = self.norm(model_output)
         noise_norm = self.norm(noise)
         step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
-        step_size = self.repeat_scalar(step_size, x.shape[0])  # * self.ones(x.shape[0], device=x.device)
+        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
+        # self.repeat_scalar(step_size, sample.shape[0])

-        # compute corrected sample: score term and noise term
-        x_mean = x + step_size[:, None, None, None] * score
-        x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+        # compute corrected sample: model_output term and noise term
+        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise

-        return x
+        return {"prev_sample": prev_sample}

     def __len__(self):
         return self.config.num_train_timesteps
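The new `seed` argument makes individual steps reproducible by routing through `set_seed` before noise is drawn. A sketch under the same import and default-parameter assumptions as above:

import torch
from diffusers import ScoreSdeVeScheduler

scheduler = ScoreSdeVeScheduler(sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="pt")
scheduler.set_timesteps(10)
scheduler.set_sigmas(10)

sample = torch.randn(2, 3, 8, 8)
model_output = 0.1 * sample
out_a = scheduler.step_pred(model_output, 0.5, sample, seed=0)["prev_sample"]
out_b = scheduler.step_pred(model_output, 0.5, sample, seed=0)["prev_sample"]
assert torch.equal(out_a, out_b)  # same seed, same noise, identical steps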
@@ -53,16 +53,6 @@ class SchedulerMixin:
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

-    def long(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.int64(tensor)
-        elif tensor_format == "pt":
-            return tensor.long()
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
     def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
         """
         Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
@@ -103,15 +93,6 @@ class SchedulerMixin:
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

-    def repeat_scalar(self, tensor, count):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.repeat(tensor, count)
-        elif tensor_format == "pt":
-            return torch.repeat_interleave(tensor, count)
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
     def zeros_like(self, tensor):
         tensor_format = getattr(self, "tensor_format", "pt")
         if tensor_format == "np":
...
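With `long` and `repeat_scalar` removed from the mixin, their remaining call sites in the VE scheduler use the PyTorch one-liners directly (see the `step_pred` hunk above). The equivalence, as a standalone sketch:

import torch

t, batch_size = 0.5, 4
timestep = t * torch.ones(batch_size)  # was: self.repeat_scalar(t, batch_size)
indices = (timestep * 9).long()        # was: self.long(timestep * 9)
print(indices)                         # tensor([4, 4, 4, 4])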
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pdb
 import tempfile
 import unittest
@@ -507,8 +508,42 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 0.2593) < 1e-3


-class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
+class ScoreSdeVeSchedulerTest(unittest.TestCase):
+    # TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration)
     scheduler_classes = (ScoreSdeVeScheduler,)
+    forward_default_kwargs = (("seed", 0),)
+
+    @property
+    def dummy_sample(self):
+        batch_size = 4
+        num_channels = 3
+        height = 8
+        width = 8
+
+        sample = torch.rand((batch_size, num_channels, height, width))
+
+        return sample
+
+    @property
+    def dummy_sample_deter(self):
+        batch_size = 4
+        num_channels = 3
+        height = 8
+        width = 8
+
+        num_elems = batch_size * num_channels * height * width
+        sample = torch.arange(num_elems)
+        sample = sample.reshape(num_channels, height, width, batch_size)
+        sample = sample / num_elems
+        sample = sample.permute(3, 0, 1, 2)
+
+        return sample
+
+    def dummy_model(self):
+        def model(sample, t, *args):
+            return sample * t / (t + 1)
+
+        return model

     def get_scheduler_config(self, **kwargs):
         config = {
@@ -517,7 +552,7 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             "sigma_min": 0.01,
             "sigma_max": 1348,
             "sampling_eps": 1e-5,
-            "tensor_format": "np",  # TODO add test for tensor formats
+            "tensor_format": "pt",  # TODO add test for tensor formats
         }

         config.update(**kwargs)
@@ -538,15 +573,15 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             scheduler.save_config(tmpdirname)
             new_scheduler = scheduler_class.from_config(tmpdirname)

-            output = scheduler.step_pred(residual, sample, time_step, **kwargs)
-            new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+            output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]

-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-            output = scheduler.step_correct(residual, sample, **kwargs)
-            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+            output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]

-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"

     def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)
@@ -564,15 +599,15 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             scheduler.save_config(tmpdirname)
             new_scheduler = scheduler_class.from_config(tmpdirname)

-            output = scheduler.step_pred(residual, sample, time_step, **kwargs)
-            new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+            output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]

-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-            output = scheduler.step_correct(residual, sample, **kwargs)
-            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+            output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]

-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"

     def test_timesteps(self):
         for timesteps in [10, 100, 1000]:
@@ -583,11 +618,12 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)

     def test_time_indices(self):
-        for t in [1, 5, 10]:
+        for t in [0.1, 0.5, 0.75]:
             self.check_over_forward(time_step=t)

     def test_full_loop_no_noise(self):
-        np.random.seed(0)
+        kwargs = dict(self.forward_default_kwargs)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
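The switch from integer to fractional time indices in `test_time_indices` reflects that VE SDE time is continuous on (sampling_eps, 1]; `step_pred` maps it to a discrete sigma index exactly as in the scheduler hunk above. A standalone sketch of that indexing:

import torch

t = 0.75                   # continuous time in (0, 1]
num_timesteps = 10         # len(scheduler.timesteps) after set_timesteps(10)
timestep = t * torch.ones(4)  # broadcast over a batch of 4
indices = (timestep * (num_timesteps - 1)).long()
print(indices)             # tensor([6, 6, 6, 6])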
@@ -598,52 +634,27 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
         sample = self.dummy_sample_deter

         scheduler.set_sigmas(num_inference_steps)
+        scheduler.set_timesteps(num_inference_steps)

         for i, t in enumerate(scheduler.timesteps):
             sigma_t = scheduler.sigmas[i]

             for _ in range(scheduler.correct_steps):
                 with torch.no_grad():
-                    result = model(sample, sigma_t)
-                sample = scheduler.step_correct(result, sample)
+                    model_output = model(sample, sigma_t)
+                sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"]

             with torch.no_grad():
-                result = model(sample, sigma_t)
-            sample, sample_mean = scheduler.step_pred(result, sample, t)
-
-        result_sum = np.sum(np.abs(sample))
-        result_mean = np.mean(np.abs(sample))
-
-        assert abs(result_sum.item() - 10629923278.7104) < 1e-2
-        assert abs(result_mean.item() - 13841045.9358) < 1e-3
-
-    def test_from_pretrained_save_pretrained(self):
-        kwargs = dict(self.forward_default_kwargs)
-
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample = self.dummy_sample
-            residual = 0.1 * sample
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(**scheduler_config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                scheduler.save_config(tmpdirname)
-                new_scheduler = scheduler_class.from_config(tmpdirname)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                new_scheduler.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            output = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+                model_output = model(sample, sigma_t)
+            output = scheduler.step_pred(model_output, t, sample, **kwargs)
+            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 14224664576.0) < 1e-2
+        assert abs(result_mean.item() - 18521698.0) < 1e-3

     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
@@ -667,31 +678,3 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):

         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
-
-    def test_pytorch_equal_numpy(self):
-        kwargs = dict(self.forward_default_kwargs)
-
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample = self.dummy_sample
-            residual = 0.1 * sample
-
-            sample_pt = torch.tensor(sample)
-            residual_pt = 0.1 * sample_pt
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(**scheduler_config)
-
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            output = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
-            output_pt = scheduler_pt.step_pred(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
-
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"