Unverified Commit 9b638548 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Improve reproduceability 2/3 (#1906)

* [Repro] Correct reproducability

* up

* up

* uP

* up

* need better image

* allow conversion from no state dict checkpoints

* up

* up

* up

* up

* check tensors

* check tensors

* check tensors

* check tensors

* next try

* up

* up

* better name

* up

* up

* Apply suggestions from code review

* correct more

* up

* replace all torch randn

* fix

* correct

* correct

* finish

* fix more

* up
parent 67e2f95c
......@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
......@@ -32,7 +31,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -267,20 +266,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -32,7 +32,14 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, replace_example_docstring
from ...utils import (
PIL_INTERPOLATION,
deprecate,
is_accelerate_available,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -464,16 +471,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
init_latents = torch.cat([init_latents], dim=0)
rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape
if isinstance(generator, list):
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
......
......@@ -19,14 +19,13 @@ import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -470,20 +469,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -19,7 +19,6 @@ import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
......@@ -33,7 +32,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -414,7 +413,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
init_latents_orig = init_latents
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents, init_latents_orig, noise
......
......@@ -21,7 +21,7 @@ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, logging
from ...utils import is_accelerate_available, logging, randn_tensor
from . import StableDiffusionPipelineOutput
......@@ -308,11 +308,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
......
......@@ -19,12 +19,11 @@ import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -313,11 +312,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
......@@ -450,11 +445,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device)
else:
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
......
......@@ -18,7 +18,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, is_accelerate_available, logging
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker
......@@ -429,20 +429,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -18,6 +18,7 @@ import torch
from ...models import UNet2DModel
from ...schedulers import KarrasVeScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -81,8 +82,7 @@ class KarrasVePipeline(DiffusionPipeline):
model = self.unet
# sample x_0 ~ N(0, sigma_0^2 * I)
sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
sample = sample.to(self.device)
sample = randn_tensor(shape, device=self.device) * self.scheduler.init_noise_sigma
self.scheduler.set_timesteps(num_inference_steps)
......
......@@ -24,7 +24,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging, torch_randn
from ...utils import is_accelerate_available, logging, randn_tensor
from .text_proj import UnCLIPTextProjModel
......@@ -105,7 +105,7 @@ class UnCLIPPipeline(DiffusionPipeline):
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
......@@ -499,7 +499,6 @@ class UnCLIPPipeline(DiffusionPipeline):
).prev_sample
image = super_res_latents
# done super res
# post processing
......
......@@ -29,7 +29,7 @@ from transformers import (
from ...models import UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging, torch_randn
from ...utils import is_accelerate_available, logging, randn_tensor
from .text_proj import UnCLIPTextProjModel
......@@ -113,7 +113,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
......
......@@ -29,7 +29,7 @@ from transformers import (
from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel
......@@ -382,20 +382,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -248,20 +248,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -22,7 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIP
from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel
......@@ -298,20 +298,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -23,7 +23,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
from .scheduling_utils import SchedulerMixin
......@@ -324,14 +324,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
)
if variance_noise is None:
if device.type == "mps":
# randn does not work reproducibly on mps
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
variance_noise = variance_noise.to(device)
else:
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
prev_sample = prev_sample + variance
......
......@@ -22,7 +22,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
from .scheduling_utils import SchedulerMixin
......@@ -313,14 +313,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance = 0
if t > 0:
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
variance_noise = variance_noise.to(device)
else:
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
else:
......
......@@ -19,7 +19,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin
......@@ -230,15 +230,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample = sample + derivative * dt
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
prev_sample = prev_sample + noise * sigma_up
......
......@@ -19,7 +19,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin
......@@ -217,16 +217,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
......
......@@ -18,7 +18,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
from .scheduling_utils import SchedulerMixin, SchedulerOutput
......@@ -243,15 +243,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
......
......@@ -20,7 +20,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import SchedulerMixin
......@@ -147,7 +147,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
gamma = 0
# sample eps ~ N(0, S_noise^2 * I)
eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
sigma_hat = sigma + gamma * sigma
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
......
......@@ -20,7 +20,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import BaseOutput, randn_tensor
from .scheduling_utils import SchedulerMixin
......@@ -271,12 +271,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
# 5. Add noise
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
noise = noise.to(device)
else:
noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
variance = 0
......@@ -311,10 +306,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
beta = self.betas[timestep + i]
if sample.device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator)
noise = noise.to(sample.device)
else:
noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
# 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment