Unverified Commit 9b638548 authored by Patrick von Platen, committed by GitHub

Improve reproducibility 2/3 (#1906)

* [Repro] Correct reproducibility

* up

* up

* up

* up

* need better image

* allow conversion from no state dict checkpoints

* up

* up

* up

* up

* check tensors

* check tensors

* check tensors

* check tensors

* next try

* up

* up

* better name

* up

* up

* Apply suggestions from code review

* correct more

* up

* replace all torch randn

* fix

* correct

* correct

* finish

* fix more

* up
parent 67e2f95c
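The recurring change across the hunks below swaps bare `torch.randn` calls for a shared `randn_tensor` helper. The helper's definition is not part of this diff; the following is a minimal sketch of what it plausibly does, reconstructed from the branches the diff removes at each call site (a CPU fallback for MPS, where `randn` is not reproducible, and one draw per generator when a list of generators is passed). Treat everything here as inferred, not authoritative:

```python
from typing import List, Optional, Tuple, Union

import torch


def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # Sketch reconstructed from the removed call-site logic, not the actual implementation.
    rand_device = device
    if device is not None and torch.device(device).type == "mps":
        # mps randn is not reproducible across runs: draw on CPU, then move over
        rand_device = "cpu"

    if isinstance(generator, list):
        # one generator per batch element, so sample i depends only on generator[i]
        per_sample_shape = (1,) + tuple(shape)[1:]
        samples = [
            torch.randn(per_sample_shape, generator=generator[i], device=rand_device, dtype=dtype)
            for i in range(shape[0])
        ]
        return torch.cat(samples, dim=0).to(device=device)

    return torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device=device)
```

The real helper (imported below as `from ...utils import randn_tensor`) may handle more, e.g. tensor layouts or generator/device mismatches; this sketch only captures the behavior the rewritten call sites rely on.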
@@ -19,6 +19,7 @@ import tqdm
 from ...models.unet_1d import UNet1DModel
 from ...pipelines import DiffusionPipeline
+from ...utils import randn_tensor
 from ...utils.dummy_pt_objects import DDPMScheduler
@@ -127,7 +128,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
         shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
         # generate initial noise and apply our conditions (to make the trajectories start at current state)
-        x1 = torch.randn(shape, device=self.unet.device)
+        x1 = randn_tensor(shape, device=self.unet.device)
         x = self.reset_x0(x1, conditions, self.action_dim)
         x = self.to_torch(x)
...
@@ -95,7 +95,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
         causal_attention_mask = torch.full(
-            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
+            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
         )
         causal_attention_mask.triu_(1)
         causal_attention_mask = causal_attention_mask[None, ...]
...
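The mask fill value changes from `float("-inf")` to `-10000.0`. The commit message doesn't spell out why, but a plausible motivation is numerical robustness: a softmax row consisting entirely of `-inf` produces NaNs, while a large finite negative value degrades gracefully. A quick illustration:

```python
import torch

# a fully masked attention row: softmax over all -inf yields NaN (0/0)
print(torch.softmax(torch.full((4,), float("-inf")), dim=-1))  # tensor([nan, nan, nan, nan])

# a large negative finite fill keeps softmax well-defined
print(torch.softmax(torch.full((4,), -10000.0), dim=-1))  # tensor([0.25, 0.25, 0.25, 0.25])
```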
@@ -18,7 +18,7 @@ import numpy as np
 import torch
 import torch.nn as nn
-from ..utils import BaseOutput
+from ..utils import BaseOutput, randn_tensor
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@@ -323,11 +323,10 @@ class DiagonalGaussianDistribution(object):
         )

     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
-        device = self.parameters.device
-        sample_device = "cpu" if device.type == "mps" else device
-        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
         # make sure sample is on the same device as the parameters and has same dtype
-        sample = sample.to(device=device, dtype=self.parameters.dtype)
+        sample = randn_tensor(
+            self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
+        )
         x = self.mean + self.std * sample
         return x
...
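`sample` is the reparameterization trick: deterministic arithmetic on a unit Gaussian draw, so the `generator` alone pins down the result. A usage sketch; the model id is a real checkpoint but incidental to the point, and the `latent_dist` attribute follows the `AutoencoderKL.encode` output as of this commit:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
x = torch.randn(1, 3, 256, 256)  # stand-in image batch

z1 = vae.encode(x).latent_dist.sample(generator=torch.Generator().manual_seed(0))
z2 = vae.encode(x).latent_dist.sample(generator=torch.Generator().manual_seed(0))
assert torch.allclose(z1, z2)  # same seed, same latent draw
```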
@@ -31,7 +31,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
@@ -401,20 +401,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
-        else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
...
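Because `randn_tensor` draws one sample per generator when given a list (see the removed branch above), per-image seeds survive batching: image i depends only on generator i, regardless of batch size. A hedged usage sketch with placeholder model id and prompt:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# one generator per requested image
generators = [torch.Generator("cpu").manual_seed(i) for i in range(4)]
images = pipe(
    "an astronaut riding a horse",
    num_images_per_prompt=4,
    generator=generators,
).images
```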
@@ -33,7 +33,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
+from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
@@ -461,16 +461,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = torch.cat([init_latents], dim=0)
-        rand_device = "cpu" if device.type == "mps" else device
         shape = init_latents.shape
-        if isinstance(generator, list):
-            shape = (1,) + shape[1:]
-            noise = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
-            ]
-            noise = torch.cat(noise, dim=0).to(device)
-        else:
-            noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
...
@@ -23,6 +23,7 @@ from PIL import Image
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, DDPMScheduler
+from ...utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
 from .mel import Mel
@@ -126,7 +127,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         input_dims = self.get_input_dims()
         self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
         if noise is None:
-            noise = torch.randn(
+            noise = randn_tensor(
                 (
                     batch_size,
                     self.unet.in_channels,
...
@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
 import torch
-from ...utils import logging
+from ...utils import logging, randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -100,16 +100,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-        rand_device = "cpu" if self.device.type == "mps" else self.device
-        if isinstance(generator, list):
-            shape = (1,) + shape[1:]
-            audio = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
-                for i in range(batch_size)
-            ]
-            audio = torch.cat(audio, dim=0).to(self.device)
-        else:
-            audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
+        audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)
         # set step values
         self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
...
@@ -16,7 +16,7 @@ from typing import List, Optional, Tuple, Union
 import torch
-from ...utils import deprecate
+from ...utils import deprecate, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -103,17 +103,7 @@ class DDIMPipeline(DiffusionPipeline):
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-        rand_device = "cpu" if self.device.type == "mps" else self.device
-        if isinstance(generator, list):
-            shape = (1,) + image_shape[1:]
-            image = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
-                for i in range(batch_size)
-            ]
-            image = torch.cat(image, dim=0).to(self.device)
-        else:
-            image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
-            image = image.to(self.device)
+        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
...
@@ -18,7 +18,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 from ...configuration_utils import FrozenDict
-from ...utils import deprecate
+from ...utils import deprecate, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -100,10 +100,10 @@ class DDPMPipeline(DiffusionPipeline):
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
-            image = torch.randn(image_shape, generator=generator)
+            image = randn_tensor(image_shape, generator=generator)
             image = image.to(self.device)
         else:
-            image = torch.randn(image_shape, generator=generator, device=self.device)
+            image = randn_tensor(image_shape, generator=generator, device=self.device)
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
...
@@ -26,6 +26,7 @@ from transformers.utils import logging
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -143,20 +144,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if self.device.type == "mps" else self.device
-            if isinstance(generator, list):
-                latents_shape = (1,) + latents_shape[1:]
-                latents = [
-                    torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0)
-            else:
-                latents = torch.randn(
-                    latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
-                )
-            latents = latents.to(self.device)
+            latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
...
@@ -16,7 +16,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION, deprecate
+from ...utils import PIL_INTERPOLATION, deprecate, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -121,12 +121,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
         latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
         latents_dtype = next(self.unet.parameters()).dtype
-        if self.device.type == "mps":
-            # randn does not work reproducibly on mps
-            latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
-            latents = latents.to(self.device)
-        else:
-            latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+        latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         image = image.to(device=self.device, dtype=latents_dtype)
...
@@ -19,6 +19,7 @@ import torch
 from ...models import UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler
+from ...utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -71,7 +72,7 @@ class LDMPipeline(DiffusionPipeline):
                 True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """
-        latents = torch.randn(
+        latents = randn_tensor(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
         )
...
@@ -34,7 +34,7 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
         self.proj_out = nn.Linear(config.hidden_size, self.proj_size)
         # uncondition for scaling
-        self.uncond_vector = nn.Parameter(torch.rand((1, 1, self.proj_size)))
+        self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
     def forward(self, pixel_values):
         clip_output = self.model(pixel_values=pixel_values)
...
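This one-character fix changes the distribution of the unconditional embedding's initialization: `torch.rand` is uniform on [0, 1), `torch.randn` is standard normal. A quick check:

```python
import torch

u = torch.rand(100_000)   # uniform on [0, 1): mean ~0.5, std ~0.29
n = torch.randn(100_000)  # standard normal: mean ~0.0, std ~1.0
print(u.mean().item(), u.std().item())
print(n.mean().item(), n.std().item())
```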
@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import logging
+from ...utils import logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -300,20 +300,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
-        else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
...
@@ -19,6 +19,7 @@ import torch
 from ...models import UNet2DModel
 from ...schedulers import PNDMScheduler
+from ...utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -72,11 +73,11 @@ class PNDMPipeline(DiffusionPipeline):
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
         # Sample gaussian noise to begin loop
-        image = torch.randn(
+        image = randn_tensor(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)
         self.scheduler.set_timesteps(num_inference_steps)
         for t in self.progress_bar(self.scheduler.timesteps):
...
@@ -22,7 +22,7 @@ import PIL
 from ...models import UNet2DModel
 from ...schedulers import RePaintScheduler
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -143,18 +143,8 @@ class RePaintPipeline(DiffusionPipeline):
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-        rand_device = "cpu" if self.device.type == "mps" else self.device
         image_shape = original_image.shape
-        if isinstance(generator, list):
-            shape = (1,) + image_shape[1:]
-            image = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
-                for i in range(batch_size)
-            ]
-            image = torch.cat(image, dim=0).to(self.device)
-        else:
-            image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
-            image = image.to(self.device)
+        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
         # set step values
         self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
...
@@ -18,6 +18,7 @@ import torch
 from ...models import UNet2DModel
 from ...schedulers import ScoreSdeVeScheduler
+from ...utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -69,7 +70,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         model = self.unet
-        sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
+        sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
         sample = sample.to(self.device)
         self.scheduler.set_timesteps(num_inference_steps)
...
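Note the call-site detail here: `torch.randn(*shape, ...)` unpacks the dimensions, while the helper takes the shape tuple itself. Illustrative only; the import path matches what this commit uses internally:

```python
import torch
from diffusers.utils import randn_tensor

shape = (4, 3, 32, 32)
a = torch.randn(*shape)  # varargs form: randn(4, 3, 32, 32)
b = randn_tensor(shape)  # the helper expects the tuple
assert a.shape == b.shape
```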
@@ -26,7 +26,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -76,7 +76,7 @@ def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta
     # direction pointing to x_t
     e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
     dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
-    noise = std_dev_t * torch.randn(
+    noise = std_dev_t * randn_tensor(
         clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator
     )
     prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
@@ -472,16 +472,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
         # add noise to latents using the timestep
-        rand_device = "cpu" if device.type == "mps" else device
         shape = init_latents.shape
-        if isinstance(generator, list):
-            shape = (1,) + shape[1:]
-            noise = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
-            ]
-            noise = torch.cat(noise, dim=0).to(device)
-        else:
-            noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         # get latents
         clean_latents = init_latents
...
@@ -30,7 +30,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import deprecate, is_accelerate_available, logging, replace_example_docstring
+from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -398,20 +398,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
             )
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
-        else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
...
@@ -20,7 +20,6 @@ import numpy as np
 import torch
 import PIL
-from diffusers.utils import is_accelerate_available
 from packaging import version
 from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
@@ -34,7 +33,7 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -381,16 +380,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = torch.cat([init_latents], dim=0)
-        rand_device = "cpu" if device.type == "mps" else device
         shape = init_latents.shape
-        if isinstance(generator, list):
-            shape = (1,) + shape[1:]
-            noise = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
-            ]
-            noise = torch.cat(noise, dim=0).to(device)
-        else:
-            noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
...