Unverified Commit 88fa6b7d authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Dance Diffusion] Add dance diffusion (#803)



* start

* add more logic

* Update src/diffusers/models/unet_2d_condition_flax.py

* match weights

* up

* make model work

* making class more general, fixing missed file rename

* small fix

* make new conversion work

* up

* finalize conversion

* up

* first batch of variable renamings

* remove c and c_prev var names

* add mid and out block structure

* add pipeline

* up

* finish conversion

* finish

* upload

* more fixes

* Apply suggestions from code review

* add attr

* up

* uP

* up

* finish tests

* finish

* uP

* finish

* fix test

* up

* naming consistency in tests

* Apply suggestions from code review
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarNathan Lambert <nathan@huggingface.co>
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>

* remove hardcoded 16

* Remove bogus

* fix some stuff

* finish

* improve logging

* docs

* upload
Co-authored-by: default avatarNathan Lambert <nol@berkeley.edu>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
Co-authored-by: default avatarNathan Lambert <nathan@huggingface.co>
Co-authored-by: default avatarAnton Lozhkov <anton@huggingface.co>
parent 0b42b074
...@@ -19,6 +19,7 @@ from ..utils import is_flax_available, is_scipy_available, is_torch_available ...@@ -19,6 +19,7 @@ from ..utils import is_flax_available, is_scipy_available, is_torch_available
if is_torch_available(): if is_torch_available():
from .scheduling_ddim import DDIMScheduler from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm import DDPMScheduler
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler
......
# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class IPNDMScheduler(SchedulerMixin, ConfigMixin):
"""
Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion
[library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296)
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
"""
@register_to_config
def __init__(self, num_train_timesteps: int = 1000):
# set `betas`, `alphas`, `timesteps`
self.set_timesteps(num_train_timesteps)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
# running values
self.ets = []
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
steps = torch.cat([steps, torch.tensor([0.0])])
self.betas = torch.sin(steps * math.pi / 2) ** 2
self.alphas = (1.0 - self.betas**2) ** 0.5
timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
self.timesteps = timesteps.to(device)
self.ets = []
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
timestep_index = (self.timesteps == timestep).nonzero().item()
prev_timestep_index = timestep_index + 1
ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
self.ets.append(ets)
if len(self.ets) == 1:
ets = self.ets[-1]
elif len(self.ets) == 2:
ets = (3 * self.ets[-1] - self.ets[-2]) / 2
elif len(self.ets) == 3:
ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
else:
ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
alpha = self.alphas[timestep_index]
sigma = self.betas[timestep_index]
next_alpha = self.alphas[prev_timestep_index]
next_sigma = self.betas[prev_timestep_index]
pred = (sample - sigma * ets) / max(alpha, 1e-8)
prev_sample = next_alpha * pred + ets * next_sigma
return prev_sample
def __len__(self):
return self.config.num_train_timesteps
...@@ -34,6 +34,21 @@ class AutoencoderKL(metaclass=DummyObject): ...@@ -34,6 +34,21 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNet2DConditionModel(metaclass=DummyObject): class UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -122,6 +137,21 @@ class DiffusionPipeline(metaclass=DummyObject): ...@@ -122,6 +137,21 @@ class DiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class DanceDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMPipeline(metaclass=DummyObject): class DDIMPipeline(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -242,6 +272,21 @@ class DDPMScheduler(metaclass=DummyObject): ...@@ -242,6 +272,21 @@ class DDPMScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class IPNDMScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class KarrasVeScheduler(metaclass=DummyObject): class KarrasVeScheduler(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
torch.backends.cuda.matmul.allow_tf32 = False
class PipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def dummy_unet(self):
torch.manual_seed(0)
model = UNet1DModel(
block_out_channels=(32, 32, 64),
extra_in_channels=16,
sample_size=512,
sample_rate=16_000,
in_channels=2,
out_channels=2,
down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"],
up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"],
)
return model
def test_dance_diffusion(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
scheduler = IPNDMScheduler()
pipe = DanceDiffusionPipeline(unet=self.dummy_unet, scheduler=scheduler)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(0)
output = pipe(generator=generator, num_inference_steps=4)
audio = output.audios
generator = torch.Generator(device=device).manual_seed(0)
output = pipe(generator=generator, num_inference_steps=4, return_dict=False)
audio_from_tuple = output[0]
audio_slice = audio[0, -3:, -3:]
audio_from_tuple_slice = audio_from_tuple[0, -3:, -3:]
assert audio.shape == (1, 2, self.dummy_unet.sample_size)
expected_slice = np.array([-0.7265, 1.0000, -0.8388, 0.1175, 0.9498, -1.0000])
assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(audio_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class PipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_dance_diffusion(self):
device = torch_device
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(0)
output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096)
audio = output.audios
audio_slice = audio[0, -3:, -3:]
assert audio.shape == (1, 2, pipe.unet.sample_size)
expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487])
assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import UNet1DModel
from diffusers.utils import slow, torch_device
torch.backends.cuda.matmul.allow_tf32 = False
class UnetModel1DTests(unittest.TestCase):
@slow
def test_unet_1d_maestro(self):
model_id = "harmonai/maestro-150k"
model = UNet1DModel.from_pretrained(model_id, subfolder="unet")
model.to(torch_device)
sample_size = 65536
noise = torch.sin(torch.arange(sample_size)[None, None, :].repeat(1, 2, 1)).to(torch_device)
timestep = torch.tensor([1]).to(torch_device)
with torch.no_grad():
output = model(noise, timestep).sample
output_sum = output.abs().sum()
output_max = output.abs().max()
assert (output_sum - 224.0896).abs() < 4e-2
assert (output_max - 0.0607).abs() < 4e-4
...@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin ...@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
class UnetModelTests(ModelTesterMixin, unittest.TestCase): class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DModel model_class = UNet2DModel
@property @property
......
...@@ -19,7 +19,14 @@ from typing import Dict, List, Tuple ...@@ -19,7 +19,14 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import torch import torch
from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler from diffusers import (
DDIMScheduler,
DDPMScheduler,
IPNDMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
ScoreSdeVeScheduler,
)
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
...@@ -203,6 +210,10 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -203,6 +210,10 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", 50) num_inference_steps = kwargs.pop("num_inference_steps", 50)
timestep = 0
if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
timestep = 1
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
...@@ -215,14 +226,14 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -215,14 +226,14 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
outputs_dict = scheduler.step(residual, 0, sample, **kwargs) outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps) scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps kwargs["num_inference_steps"] = num_inference_steps
outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs) outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)
recursive_check(outputs_tuple, outputs_dict) recursive_check(outputs_tuple, outputs_dict)
...@@ -901,3 +912,157 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -901,3 +912,157 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3 assert abs(result_mean.item() - 1.31) < 1e-3
class IPNDMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (IPNDMScheduler,)
forward_default_kwargs = (("num_inference_steps", 50),)
def get_scheduler_config(self, **kwargs):
config = {"num_train_timesteps": 1000}
config.update(**kwargs)
return config
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
if time_step is None:
time_step = scheduler.timesteps[len(scheduler.timesteps) // 2]
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_pretrained_save_pretrained(self):
pass
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals (must be after setting timesteps)
scheduler.ets = dummy_past_residuals[:]
if time_step is None:
time_step = scheduler.timesteps[len(scheduler.timesteps) // 2]
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residual (must be after setting timesteps)
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
return sample
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
time_step_0 = scheduler.timesteps[5]
time_step_1 = scheduler.timesteps[6]
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self):
for timesteps in [100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps, time_step=None)
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=None)
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 2540529) < 10
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment