Unverified Commit 9b638548 authored by Patrick von Platen, committed by GitHub

Improve reproducibility 2/3 (#1906)

* [Repro] Correct reproducibility

* up

* up

* uP

* up

* need better image

* allow conversion from no state dict checkpoints

* up

* up

* up

* up

* check tensors

* check tensors

* check tensors

* check tensors

* next try

* up

* up

* better name

* up

* up

* Apply suggestions from code review

* correct more

* up

* replace all torch randn

* fix

* correct

* correct

* finish

* fix more

* up
parent 67e2f95c
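
The hunks below consolidate the scattered `torch.randn` call sites, together with their per-pipeline MPS and list-of-generators workarounds, into a single `randn_tensor` utility imported from `diffusers.utils`. For orientation, here is a minimal sketch of what such a helper might look like, reconstructed purely from the inline code the diffs remove; the actual implementation and signature in `diffusers.utils` may differ.

# Hypothetical sketch of the shared noise-sampling helper, inferred from the
# per-pipeline code removed below; not the verbatim diffusers implementation.
from typing import List, Optional, Union

import torch


def randn_tensor(
    shape: Union[tuple, List[int]],
    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # Sample on CPU when the target is MPS, since torch.randn is not reproducible
    # there, then move the result to the requested device.
    rand_device = "cpu" if device is not None and device.type == "mps" else device

    if isinstance(generator, list):
        # One generator per batch element: draw each sample separately so the
        # noise for a given sample does not depend on the batch size.
        per_sample_shape = (1,) + tuple(shape)[1:]
        samples = [
            torch.randn(per_sample_shape, generator=generator[i], device=rand_device, dtype=dtype)
            for i in range(shape[0])
        ]
        noise = torch.cat(samples, dim=0)
    else:
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype)

    return noise.to(device) if device is not None else noise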
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
 import torch
 import PIL
-from diffusers.utils import is_accelerate_available
 from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
@@ -32,7 +31,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -267,20 +266,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
......
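
One practical consequence of the list-of-generators branch that the helper absorbs: with one generator per image, the noise drawn for a given image does not depend on the batch size. The following illustrative snippet reuses the hypothetical `randn_tensor` sketch above; seeds and shapes are arbitrary.

import torch

# Hypothetical usage of the randn_tensor sketch above; illustrative only.
generators = [torch.Generator("cpu").manual_seed(seed) for seed in (0, 1, 2, 3)]
batched = randn_tensor((4, 4, 64, 64), generator=generators, device=torch.device("cpu"), dtype=torch.float32)

single = randn_tensor(
    (1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0), device=torch.device("cpu"), dtype=torch.float32
)

# The first latent of the batched draw matches the single-image draw with the same seed.
assert torch.allclose(batched[:1], single)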
@@ -32,7 +32,14 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, replace_example_docstring
+from ...utils import (
+    PIL_INTERPOLATION,
+    deprecate,
+    is_accelerate_available,
+    logging,
+    randn_tensor,
+    replace_example_docstring,
+)
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -464,16 +471,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = torch.cat([init_latents], dim=0)
-        rand_device = "cpu" if device.type == "mps" else device
         shape = init_latents.shape
-        if isinstance(generator, list):
-            shape = (1,) + shape[1:]
-            noise = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
-            ]
-            noise = torch.cat(noise, dim=0).to(device)
-        else:
-            noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
......
@@ -19,14 +19,13 @@ import numpy as np
 import torch
 import PIL
-from diffusers.utils import is_accelerate_available
 from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -470,20 +469,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
......
@@ -19,7 +19,6 @@ import numpy as np
 import torch
 import PIL
-from diffusers.utils import is_accelerate_available
 from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -33,7 +32,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -414,7 +413,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         init_latents_orig = init_latents
         # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
+        noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
         latents = init_latents
         return latents, init_latents_orig, noise
......
@@ -21,7 +21,7 @@ from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
 from ...pipelines import DiffusionPipeline
 from ...schedulers import LMSDiscreteScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, randn_tensor
 from . import StableDiffusionPipelineOutput
@@ -308,11 +308,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
......
@@ -19,12 +19,11 @@ import numpy as np
 import torch
 import PIL
-from diffusers.utils import is_accelerate_available
 from transformers import CLIPTextModel, CLIPTokenizer
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import logging
+from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -313,11 +312,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height, width)
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -450,11 +445,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
         # 5. Add noise to image
         noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-        if device.type == "mps":
-            # randn does not work reproducibly on mps
-            noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device)
-        else:
-            noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
+        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
         image = self.low_res_scheduler.add_noise(image, noise, noise_level)
         batch_multiplier = 2 if do_classifier_free_guidance else 1
......
@@ -18,7 +18,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import deprecate, is_accelerate_available, logging
+from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionSafePipelineOutput
 from .safety_checker import SafeStableDiffusionSafetyChecker
@@ -429,20 +429,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
......
@@ -18,6 +18,7 @@ import torch
 from ...models import UNet2DModel
 from ...schedulers import KarrasVeScheduler
+from ...utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -81,8 +82,7 @@ class KarrasVePipeline(DiffusionPipeline):
         model = self.unet
         # sample x_0 ~ N(0, sigma_0^2 * I)
-        sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
-        sample = sample.to(self.device)
+        sample = randn_tensor(shape, device=self.device) * self.scheduler.init_noise_sigma
         self.scheduler.set_timesteps(num_inference_steps)
......
@@ -24,7 +24,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging, torch_randn
+from ...utils import is_accelerate_available, logging, randn_tensor
 from .text_proj import UnCLIPTextProjModel
@@ -105,7 +105,7 @@ class UnCLIPPipeline(DiffusionPipeline):
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -499,7 +499,6 @@ class UnCLIPPipeline(DiffusionPipeline):
             ).prev_sample
         image = super_res_latents
-        # done super res
         # post processing
......
@@ -29,7 +29,7 @@ from transformers import (
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging, torch_randn
+from ...utils import is_accelerate_available, logging, randn_tensor
 from .text_proj import UnCLIPTextProjModel
@@ -113,7 +113,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
......
@@ -29,7 +29,7 @@ from transformers import (
 from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel
@@ -382,20 +382,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
......
@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -248,20 +248,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
......
@@ -22,7 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIP
 from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel
@@ -298,20 +298,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
......
@@ -23,7 +23,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
 from .scheduling_utils import SchedulerMixin
@@ -324,14 +324,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
                 )
             if variance_noise is None:
-                if device.type == "mps":
-                    # randn does not work reproducibly on mps
-                    variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-                    variance_noise = variance_noise.to(device)
-                else:
-                    variance_noise = torch.randn(
-                        model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                    )
+                variance_noise = randn_tensor(
+                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+                )
             variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
             prev_sample = prev_sample + variance
......
@@ -22,7 +22,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
 from .scheduling_utils import SchedulerMixin
@@ -313,14 +313,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         variance = 0
         if t > 0:
             device = model_output.device
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-                variance_noise = variance_noise.to(device)
-            else:
-                variance_noise = torch.randn(
-                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                )
+            variance_noise = randn_tensor(
+                model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+            )
             if self.variance_type == "fixed_small_log":
                 variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
             else:
......
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
 from .scheduling_utils import SchedulerMixin
@@ -230,15 +230,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         prev_sample = sample + derivative * dt
         device = model_output.device
-        if device.type == "mps":
-            # randn does not work reproducibly on mps
-            noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
-                device
-            )
-        else:
-            noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
-                device
-            )
+        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
         prev_sample = prev_sample + noise * sigma_up
......
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
 from .scheduling_utils import SchedulerMixin
@@ -217,16 +217,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
-        device = model_output.device
-        if device.type == "mps":
-            # randn does not work reproducibly on mps
-            noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
-                device
-            )
-        else:
-            noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
-                device
-            )
+        noise = randn_tensor(
+            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+        )
         eps = noise * s_noise
         sigma_hat = sigma * (gamma + 1)
......
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -243,15 +243,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now
         device = model_output.device
-        if device.type == "mps":
-            # randn does not work reproducibly on mps
-            noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
-                device
-            )
-        else:
-            noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
-                device
-            )
+        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         if self.config.prediction_type == "epsilon":
......
@@ -20,7 +20,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, randn_tensor
 from .scheduling_utils import SchedulerMixin
@@ -147,7 +147,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
             gamma = 0
         # sample eps ~ N(0, S_noise^2 * I)
-        eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
+        eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
         sigma_hat = sigma + gamma * sigma
         sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
......
@@ -20,7 +20,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, randn_tensor
 from .scheduling_utils import SchedulerMixin
@@ -271,12 +271,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
         # 5. Add noise
         device = model_output.device
-        if device.type == "mps":
-            # randn does not work reproducibly on mps
-            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-            noise = noise.to(device)
-        else:
-            noise = torch.randn(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
+        noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
         std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
         variance = 0
@@ -311,10 +306,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
         beta = self.betas[timestep + i]
         if sample.device.type == "mps":
             # randn does not work reproducibly on mps
-            noise = torch.randn(sample.shape, dtype=sample.dtype, generator=generator)
+            noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator)
             noise = noise.to(sample.device)
         else:
-            noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
+            noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
         # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
         sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
......
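
Since the stochastic schedulers now route their variance noise through the same helper, passing an explicit generator to step() should make ancestral and variance sampling deterministic end to end. A small hedged check with dummy tensors and an arbitrary seed (illustrative only, not part of this commit):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

# Dummy model output and sample, just to exercise the scheduler step.
sample = torch.zeros(1, 3, 8, 8)
model_output = torch.zeros(1, 3, 8, 8)
t = scheduler.timesteps[0]

out_a = scheduler.step(model_output, t, sample, generator=torch.Generator("cpu").manual_seed(0)).prev_sample
out_b = scheduler.step(model_output, t, sample, generator=torch.Generator("cpu").manual_seed(0)).prev_sample

# Same seed, same variance noise, identical results.
assert torch.equal(out_a, out_b)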