Commit f82ebb9a authored by Patrick von Platen

fix some model tests

parent 63c68d97
@@ -35,7 +35,7 @@ class DDPMPipeline(DiffusionPipeline):
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
         )
         image = image.to(torch_device)
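For context, the rename only touches the attribute name; the sampling entry point is still plain seeded Gaussian noise. A minimal sketch of the pattern, with hypothetical stand-ins for the UNet config values:

```python
import torch

# hypothetical values standing in for self.unet.in_channels / self.unet.image_size
batch_size, in_channels, image_size = 1, 3, 32

generator = torch.manual_seed(0)  # seed the RNG so the run is reproducible
image = torch.randn((batch_size, in_channels, image_size, image_size), generator=generator)
print(image.shape)  # torch.Size([1, 3, 32, 32])
```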
...
@@ -90,6 +90,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.num_inference_steps = None
         self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
 
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
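Storing `tensor_format` and calling `set_format` in the constructor lets `set_timesteps` (next hunk) re-apply the numpy-to-torch conversion after it rebuilds `self.timesteps`. A rough sketch of what a `set_format` helper of this shape does — an illustration, not the library's exact implementation:

```python
import numpy as np
import torch

class FormatMixin:
    """Sketch of the helper, not the library's exact implementation."""

    def set_format(self, tensor_format="pt"):
        # promote every numpy-array attribute to a torch tensor when the
        # scheduler is switched into PyTorch mode
        self.tensor_format = tensor_format
        if tensor_format == "pt":
            for key, value in list(vars(self).items()):
                if isinstance(value, np.ndarray):
                    setattr(self, key, torch.from_numpy(value))
        return self
```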
@@ -102,9 +105,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def set_timesteps(self, num_inference_steps):
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(0, self.config.timesteps, self.config.timesteps // self.num_inference_steps)[
-            ::-1
-        ].copy()
+        self.timesteps = np.arange(
+            0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+        )[::-1].copy()
         self.set_format(tensor_format=self.tensor_format)
 
     def step(
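The old code read `self.config.timesteps`, a key that no longer exists now that the constructor argument is named `num_train_timesteps`; the fix reads the right key and restores the tensor format afterwards. A quick worked example of the schedule this produces, assuming the usual default of 1000 training timesteps:

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 50
timesteps = np.arange(0, num_train_timesteps, num_train_timesteps // num_inference_steps)[::-1].copy()
print(timesteps[:3], timesteps[-3:])  # [980 960 940] [40 20 0]
```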
...
@@ -15,7 +15,6 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
 # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
-import pdb
 
 import numpy as np
 import torch
@@ -110,10 +109,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         Predict the sample at the previous timestep by reversing the SDE.
         """
         # TODO(Patrick) better comments + non-PyTorch
-        t = self.repeat_scalar(t, x.shape[0])
-        timesteps = self.long((t * (len(self.timesteps) - 1)))
-        sigma = self.discrete_sigmas[timesteps]
+        t = self.repeat_scalar(t, x.shape[0]).to(x.device)
+        timesteps = self.long((t * (len(self.timesteps) - 1))).to(x.device)
+        sigma = self.discrete_sigmas[timesteps].to(x.device)
         adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
         drift = self.zeros_like(x)
         diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
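The three `.to(x.device)` calls address the usual CPU/GPU mismatch: the scheduler's buffers are created on the CPU while the sample `x` may already live on a GPU, and arithmetic across devices raises a RuntimeError. A minimal reproduction of the pattern, assuming a CUDA-capable machine:

```python
import torch

x = torch.randn(2, 3, device="cuda")       # sample already moved to the GPU
sigmas = torch.linspace(0.01, 50.0, 1000)  # scheduler buffer, created on the CPU
idx = torch.tensor([10, 20])

# sigmas[idx] is a CPU tensor; combining it with x directly would raise
# "Expected all tensors to be on the same device"
sigma = sigmas[idx].to(x.device)
out = x * sigma.view(-1, 1)                # now both operands are on the GPU
```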
...
@@ -911,7 +911,7 @@ class PipelineTesterMixin(unittest.TestCase):
             down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
             up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
         )
-        schedular = DDPMScheduler(timesteps=10)
+        schedular = DDPMScheduler(num_train_timesteps=10)
         ddpm = DDPMPipeline(model, schedular)
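With this rename the test matches the constructor argument used by the schedulers elsewhere in this commit. Assuming the package-level export, the corrected construction looks like:

```python
from diffusers import DDPMScheduler  # assuming the package-level export

scheduler = DDPMScheduler(num_train_timesteps=10)  # new keyword name
# DDPMScheduler(timesteps=10) would now fail: the parameter was renamed
```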
...
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pdb
 
 import tempfile
 import unittest
@@ -618,3 +617,81 @@ class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
         assert abs(result_sum.item() - 10629923278.7104) < 1e-2
         assert abs(result_mean.item() - 13841045.9358) < 1e-3
+
+    def test_from_pretrained_save_pretrained(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+                new_scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output_0 = scheduler.step_pred(residual, 0, sample, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+    def test_pytorch_equal_numpy(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            sample_pt = torch.tensor(sample)
+            residual_pt = 0.1 * sample_pt
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+                scheduler_pt.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_pred(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
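These three tests are overridden here rather than inherited because this scheduler steps through `step_pred` instead of the common `step` API. A sketch of running just the new overrides, assuming the hunk lives in a module importable as `tests.test_scheduler` (the file path is collapsed in this view):

```python
import unittest

from tests.test_scheduler import ScoreSdeVeSchedulerTest  # hypothetical import path

suite = unittest.TestSuite()
for name in (
    "test_from_pretrained_save_pretrained",
    "test_step_shape",
    "test_pytorch_equal_numpy",
):
    suite.addTest(ScoreSdeVeSchedulerTest(name))
unittest.TextTestRunner(verbosity=2).run(suite)
```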