Unverified Commit 9b638548 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Improve reproduceability 2/3 (#1906)

* [Repro] Correct reproducability

* up

* up

* uP

* up

* need better image

* allow conversion from no state dict checkpoints

* up

* up

* up

* up

* check tensors

* check tensors

* check tensors

* check tensors

* next try

* up

* up

* better name

* up

* up

* Apply suggestions from code review

* correct more

* up

* replace all torch randn

* fix

* correct

* correct

* finish

* fix more

* up
parent 67e2f95c
......@@ -19,6 +19,7 @@ import tqdm
from ...models.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline
from ...utils import randn_tensor
from ...utils.dummy_pt_objects import DDPMScheduler
......@@ -127,7 +128,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
# generate initial noise and apply our conditions (to make the trajectories start at current state)
x1 = torch.randn(shape, device=self.unet.device)
x1 = randn_tensor(shape, device=self.unet.device)
x = self.reset_x0(x1, conditions, self.action_dim)
x = self.to_torch(x)
......
......@@ -95,7 +95,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
)
causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...]
......
......@@ -18,7 +18,7 @@ import numpy as np
import torch
import torch.nn as nn
from ..utils import BaseOutput
from ..utils import BaseOutput, randn_tensor
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
......@@ -323,11 +323,10 @@ class DiagonalGaussianDistribution(object):
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
# make sure sample is on the same device as the parameters and has same dtype
sample = sample.to(device=device, dtype=self.parameters.dtype)
sample = randn_tensor(
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
)
x = self.mean + self.std * sample
return x
......
......@@ -31,7 +31,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging, replace_example_docstring
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
......@@ -401,20 +401,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -33,7 +33,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
......@@ -461,16 +461,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
init_latents = torch.cat([init_latents], dim=0)
rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape
if isinstance(generator, list):
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
......
......@@ -23,6 +23,7 @@ from PIL import Image
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
from .mel import Mel
......@@ -126,7 +127,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
input_dims = self.get_input_dims()
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
if noise is None:
noise = torch.randn(
noise = randn_tensor(
(
batch_size,
self.unet.in_channels,
......
......@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union
import torch
from ...utils import logging
from ...utils import logging, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
......@@ -100,16 +100,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
rand_device = "cpu" if self.device.type == "mps" else self.device
if isinstance(generator, list):
shape = (1,) + shape[1:]
audio = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
audio = torch.cat(audio, dim=0).to(self.device)
else:
audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)
# set step values
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
......
......@@ -16,7 +16,7 @@ from typing import List, Optional, Tuple, Union
import torch
from ...utils import deprecate
from ...utils import deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -103,17 +103,7 @@ class DDIMPipeline(DiffusionPipeline):
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
rand_device = "cpu" if self.device.type == "mps" else self.device
if isinstance(generator, list):
shape = (1,) + image_shape[1:]
image = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
image = torch.cat(image, dim=0).to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
image = image.to(self.device)
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
# set step values
self.scheduler.set_timesteps(num_inference_steps)
......
......@@ -18,7 +18,7 @@ from typing import List, Optional, Tuple, Union
import torch
from ...configuration_utils import FrozenDict
from ...utils import deprecate
from ...utils import deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -100,10 +100,10 @@ class DDPMPipeline(DiffusionPipeline):
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
image = randn_tensor(image_shape, generator=generator)
image = image.to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=self.device)
image = randn_tensor(image_shape, generator=generator, device=self.device)
# set step values
self.scheduler.set_timesteps(num_inference_steps)
......
......@@ -26,6 +26,7 @@ from transformers.utils import logging
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -143,20 +144,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if self.device.type == "mps" else self.device
if isinstance(generator, list):
latents_shape = (1,) + latents_shape[1:]
latents = [
torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0)
else:
latents = torch.randn(
latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
)
latents = latents.to(self.device)
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
......
......@@ -16,7 +16,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate
from ...utils import PIL_INTERPOLATION, deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -121,12 +121,7 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
latents_dtype = next(self.unet.parameters()).dtype
if self.device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
latents = latents.to(self.device)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
image = image.to(device=self.device, dtype=latents_dtype)
......
......@@ -19,6 +19,7 @@ import torch
from ...models import UNet2DModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -71,7 +72,7 @@ class LDMPipeline(DiffusionPipeline):
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
"""
latents = torch.randn(
latents = randn_tensor(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
......
......@@ -34,7 +34,7 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
self.proj_out = nn.Linear(config.hidden_size, self.proj_size)
# uncondition for scaling
self.uncond_vector = nn.Parameter(torch.rand((1, 1, self.proj_size)))
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
def forward(self, pixel_values):
clip_output = self.model(pixel_values=pixel_values)
......
......@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -300,20 +300,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -19,6 +19,7 @@ import torch
from ...models import UNet2DModel
from ...schedulers import PNDMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -72,11 +73,11 @@ class PNDMPipeline(DiffusionPipeline):
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
# Sample gaussian noise to begin loop
image = torch.randn(
image = randn_tensor(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
device=self.device,
)
image = image.to(self.device)
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
......
......@@ -22,7 +22,7 @@ import PIL
from ...models import UNet2DModel
from ...schedulers import RePaintScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -143,18 +143,8 @@ class RePaintPipeline(DiffusionPipeline):
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
rand_device = "cpu" if self.device.type == "mps" else self.device
image_shape = original_image.shape
if isinstance(generator, list):
shape = (1,) + image_shape[1:]
image = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
image = torch.cat(image, dim=0).to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
image = image.to(self.device)
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
# set step values
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
......
......@@ -18,6 +18,7 @@ import torch
from ...models import UNet2DModel
from ...schedulers import ScoreSdeVeScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -69,7 +70,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
model = self.unet
sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
sample = sample.to(self.device)
self.scheduler.set_timesteps(num_inference_steps)
......
......@@ -26,7 +26,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -76,7 +76,7 @@ def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta
# direction pointing to x_t
e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
noise = std_dev_t * torch.randn(
noise = std_dev_t * randn_tensor(
clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator
)
prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise
......@@ -472,16 +472,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
# add noise to latents using the timestep
rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape
if isinstance(generator, list):
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
clean_latents = init_latents
......
......@@ -30,7 +30,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, is_accelerate_available, logging, replace_example_docstring
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
......@@ -398,20 +398,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
......
......@@ -20,7 +20,6 @@ import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
......@@ -34,7 +33,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -381,16 +380,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
else:
init_latents = torch.cat([init_latents], dim=0)
rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape
if isinstance(generator, list):
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment