"src/diffusers/pipelines/pipeline_flax_utils.py" did not exist on "1606eb994a754cc512dfa08d926e199851abc9be"
Unverified commit 7d0c2729, authored by Anton Lozhkov, committed by GitHub

Match the generator device to the pipeline for DDPM and DDIM (#1222)

* Match the generator device to the pipeline for DDPM and DDIM

* style

* fix

* update values

* fix fast tests

* trigger slow tests

* deprecate

* last value fixes

* mps fixes
Parent commit: 3d98dc76
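In short: seeded generators should now be created on the pipeline's device instead of implicitly on the CPU. A minimal usage sketch of the change (the model id is taken from the tests below; running on CUDA is an assumption):

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")

# before: generator = torch.manual_seed(0)                       # always a CPU generator
generator = torch.Generator(device=pipe.device).manual_seed(0)   # matches the pipeline device
images = pipe(generator=generator, output_type="numpy").images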
@@ -11,10 +11,12 @@ import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
+from diffusers.utils import deprecate
 from huggingface_hub import HfFolder, Repository, whoami
+from packaging import version
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -28,6 +30,7 @@ from tqdm.auto import tqdm
 logger = get_logger(__name__)

+diffusers_version = version.parse(version.parse(__version__).base_version)

 def _extract_into_tensor(arr, timesteps, broadcast_shape):
@@ -406,7 +409,11 @@ def main(args):
                 scheduler=noise_scheduler,
             )

+            deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
+            if diffusers_version < version.parse("0.8.0"):
+                generator = torch.manual_seed(0)
+            else:
+                generator = torch.Generator(device=pipeline.device).manual_seed(0)
             # run pipeline in inference (sample random noise and denoise)
             images = pipeline(
                 generator=generator,
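The double version.parse call above strips any pre-release suffix before comparing against 0.8.0, so source installs take the new branch too; a small illustration (the version string is hypothetical):

from packaging import version

raw = "0.8.0.dev0"                                       # e.g. an install from source
assert version.parse(raw) < version.parse("0.8.0")       # pre-releases sort below the final release
base = version.parse(version.parse(raw).base_version)    # base_version drops the ".dev0" suffix
assert not base < version.parse("0.8.0")                 # so the compatibility branch is skipped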
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import inspect
 from typing import Optional, Tuple, Union

 import torch

 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate

 class DDIMPipeline(DiffusionPipeline):
@@ -75,24 +75,29 @@ class DDIMPipeline(DiffusionPipeline):
                 generated images.
         """

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)

-        # Ignore use_clipped_model_output if the scheduler doesn't accept this argument
-        accepts_use_clipped_model_output = "use_clipped_model_output" in set(
-            inspect.signature(self.scheduler.step).parameters.keys()
-        )
-        extra_kwargs = {}
-        if accepts_use_clipped_model_output:
-            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
-
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t).sample
@@ -100,7 +105,9 @@ class DDIMPipeline(DiffusionPipeline):
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
+            image = self.scheduler.step(
+                model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
+            ).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
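With the inspect-based capability check gone, DDIMScheduler.step now receives eta, use_clipped_model_output, and the generator explicitly, so the variance noise drawn when eta > 0 is seeded as well. A hedged usage sketch (assumes a DDIMPipeline named pipe that has already been moved to "cuda"):

import torch

generator = torch.Generator(device="cuda").manual_seed(0)
images = pipe(
    batch_size=1,
    generator=generator,
    eta=1.0,                        # eta > 0 makes the scheduler add seeded variance noise
    use_clipped_model_output=True,  # now forwarded unconditionally
    num_inference_steps=50,
    output_type="numpy",
).images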
@@ -80,12 +80,25 @@
             new_config["predict_epsilon"] = predict_epsilon
             self.scheduler._internal_dict = FrozenDict(new_config)

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
@@ -292,10 +292,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # 6. Add noise
         variance = 0
         if t > 0:
-            noise = torch.randn(
-                model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
-            ).to(model_output.device)
-            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+            device = model_output.device
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+                variance_noise = variance_noise.to(device)
+            else:
+                variance_noise = torch.randn(
+                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+                )
+            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

         pred_prev_sample = pred_prev_sample + variance
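The same workaround generalizes beyond this scheduler; a hedged helper sketch (the function name is illustrative, not part of the diff):

import torch

def randn_like_seeded(tensor, generator):
    # Noise matching `tensor`'s shape/dtype/device while staying reproducible.
    device = tensor.device
    if device.type == "mps":
        # torch.randn is not reproducible on mps, so sample with the CPU generator and move the result.
        return torch.randn(tensor.shape, dtype=tensor.dtype, generator=generator).to(device)
    return torch.randn(tensor.shape, dtype=tensor.dtype, device=device, generator=generator)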
@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -43,21 +43,18 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()

         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -67,13 +64,12 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

 @slow
-@require_torch
+@require_torch_gpu
 class DDIMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_ema_bedroom(self):
         model_id = "google/ddpm-ema-bedroom-256"
@@ -85,13 +81,13 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        expected_slice = np.array([0.1546, 0.1561, 0.1595, 0.1564, 0.1569, 0.1585, 0.1554, 0.1550, 0.1575])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_cifar10(self):
@@ -104,11 +100,11 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddim(generator=generator, eta=0.0, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        expected_slice = np.array([0.2060, 0.2042, 0.2022, 0.2193, 0.2146, 0.2110, 0.2471, 0.2446, 0.2388])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -20,7 +20,7 @@ import torch
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -44,21 +44,18 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()

         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -68,9 +65,8 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -85,10 +81,10 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -100,7 +96,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

 @slow
-@require_torch
+@require_torch_gpu
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -112,11 +108,11 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -68,30 +68,25 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference_superresolution(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
         vqvae = self.dummy_vq_model

         ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
-        ldm.to(torch_device)
+        ldm.to(device)
         ldm.set_progress_bar_config(disable=None)

-        init_image = self.dummy_image.to(torch_device)
-
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images
+        init_image = self.dummy_image.to(device)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.8634, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

 @slow
@@ -42,7 +42,6 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
-from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -93,11 +92,17 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
+        pipe = pipe.to(torch_device)

+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
-        generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+        pipe_2 = pipe_2.to(torch_device)

+        generator_2 = generator.manual_seed(0)
         out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
@@ -107,13 +112,19 @@
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
+        pipe = pipe.to(torch_device)

+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+            pipe_2 = pipe_2.to(torch_device)

+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
@@ -121,13 +132,19 @@
     def test_load_no_safety_checker_default_locally(self):
         prompt = "hello"
         pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
+        pipe = pipe.to(torch_device)

+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+            pipe_2 = pipe_2.to(torch_device)

+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
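The mps special-casing repeats across these download tests; a hedged helper sketch of the pattern (the function name is illustrative, not part of the diff):

import torch

def seeded_generator(device, seed=0):
    # torch.Generator does not support the "mps" device type with this torch version,
    # so fall back to seeding the global CPU generator there.
    if torch.device(device).type == "mps":
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)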
@@ -431,7 +448,7 @@ class PipelineSlowTests(unittest.TestCase):
             new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
             new_ddpm.to(torch_device)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -452,7 +469,7 @@
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -475,7 +492,7 @@
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -491,7 +508,7 @@
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         images = pipe(generator=generator, output_type="numpy").images

         assert images.shape == (1, 32, 32, 3)
         assert isinstance(images, np.ndarray)
@@ -506,40 +523,8 @@
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)

-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality(self, seed):
-        model_id = "google/ddpm-cifar10-32"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler()
-        ddim_scheduler = DDIMScheduler()
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
-        ddpm.to(torch_device)
-        ddpm.set_progress_bar_config(disable=None)
-
-        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
-        ddim.to(torch_device)
-        ddim.set_progress_bar_config(disable=None)
-
-        generator = torch.manual_seed(seed)
-        ddpm_image = ddpm(generator=generator, output_type="numpy").images
-
-        generator = torch.manual_seed(seed)
-        ddim_image = ddim(
-            generator=generator,
-            num_inference_steps=1000,
-            eta=1.0,
-            output_type="numpy",
-            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
-        ).images
-
-        # the values aren't exactly equal, but the images look the same visually
-        assert np.abs(ddpm_image - ddim_image).max() < 1e-1
-
-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality_batched(self, seed):
+    def test_ddpm_ddim_equality_batched(self):
+        seed = 0
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
@@ -554,12 +539,12 @@
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(seed)
-        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
+        ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

-        generator = torch.manual_seed(seed)
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4,
+            batch_size=2,
             generator=generator,
             num_inference_steps=1000,
             eta=1.0,
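Background on the retained equality test: at eta=1.0 the DDIM variance reduces to the DDPM posterior variance, and use_clipped_model_output=True mirrors DDPM's clipping of the predicted sample, which is why the two samplers are expected to agree to within 1e-1. A small numeric check of the variance identity (the schedule values are illustrative):

import torch

alphas_cumprod = torch.cumprod(1 - torch.linspace(1e-4, 0.02, 1000), dim=0)
a_t, a_prev = alphas_cumprod[100], alphas_cumprod[99]

beta_t = 1 - a_t / a_prev                                  # per-step beta at t
ddim_var = (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)   # DDIM sigma_t**2 with eta = 1.0
ddpm_var = (1 - a_prev) / (1 - a_t) * beta_t               # DDPM posterior variance
assert torch.isclose(ddim_var, ddpm_var)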