Unverified Commit 7d0c2729 authored by Anton Lozhkov, committed by GitHub

Match the generator device to the pipeline for DDPM and DDIM (#1222)

* Match the generator device to the pipeline for DDPM and DDIM

* style

* fix

* update values

* fix fast tests

* trigger slow tests

* deprecate

* last value fixes

* mps fixes
parent 3d98dc76
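In short: these pipelines were previously seeded through the global `torch.manual_seed(...)`, which always returns a CPU generator; this commit switches callers to a `torch.Generator` created on the pipeline's own device. A minimal sketch of the new calling convention (model id and target device are placeholders):

    import torch
    from diffusers import DDPMPipeline

    pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
    pipe.to("cuda")

    # Seed a generator on the same device as the pipeline, so sampling
    # is reproducible without touching the global RNG state.
    generator = torch.Generator(device=pipe.device).manual_seed(0)
    images = pipe(generator=generator, output_type="numpy").images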
@@ -11,10 +11,12 @@ import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
+from diffusers.utils import deprecate
 from huggingface_hub import HfFolder, Repository, whoami
+from packaging import version
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -28,6 +30,7 @@ from tqdm.auto import tqdm
 logger = get_logger(__name__)

+diffusers_version = version.parse(version.parse(__version__).base_version)

 def _extract_into_tensor(arr, timesteps, broadcast_shape):
@@ -406,7 +409,11 @@ def main(args):
                 scheduler=noise_scheduler,
             )

+            deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
-            generator = torch.manual_seed(0)
+            if diffusers_version < version.parse("0.8.0"):
+                generator = torch.manual_seed(0)
+            else:
+                generator = torch.Generator(device=pipeline.device).manual_seed(0)
             # run pipeline in inference (sample random noise and denoise)
             images = pipeline(
                 generator=generator,
...
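The training script gates on the installed `diffusers` version via `packaging`; note that `base_version` strips pre-release suffixes, so a `0.8.0.dev0` install already takes the new code path. A small sketch of that comparison (the version strings are illustrative):

    from packaging import version

    # base_version drops dev/rc suffixes: "0.8.0.dev0" is treated as "0.8.0"
    installed = version.parse(version.parse("0.8.0.dev0").base_version)
    if installed < version.parse("0.8.0"):
        generator_kind = "global CPU seed (old API)"
    else:
        generator_kind = "device-matched torch.Generator (new API)"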
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 from typing import Optional, Tuple, Union

 import torch

 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate


 class DDIMPipeline(DiffusionPipeline):
@@ -75,24 +75,29 @@ class DDIMPipeline(DiffusionPipeline):
                 generated images.
         """

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None

         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)

-        # Ignore use_clipped_model_output if the scheduler doesn't accept this argument
-        accepts_use_clipped_model_output = "use_clipped_model_output" in set(
-            inspect.signature(self.scheduler.step).parameters.keys()
-        )
-        extra_kwargs = {}
-        if accepts_use_clipped_model_output:
-            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output

         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t).sample
@@ -100,7 +105,9 @@ class DDIMPipeline(DiffusionPipeline):
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
+            image = self.scheduler.step(
+                model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
+            ).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
...
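The deleted `inspect.signature` block was a generic feature test for an optional keyword argument, no longer needed here because `DDIMScheduler.step` accepts `use_clipped_model_output` directly. The idiom itself is worth keeping; a condensed, hypothetical sketch (the helper name is ours):

    import inspect

    def accepts_kwarg(fn, name: str) -> bool:
        # True if `fn` declares a parameter called `name`.
        return name in inspect.signature(fn).parameters

    # e.g. only forward the kwarg when the callee supports it:
    # extra = {"use_clipped_model_output": True} if accepts_kwarg(scheduler.step, "use_clipped_model_output") else {}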
@@ -80,12 +80,25 @@ class DDPMPipeline(DiffusionPipeline):
             new_config["predict_epsilon"] = predict_epsilon
             self.scheduler._internal_dict = FrozenDict(new_config)

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None

         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
...
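Both pipelines route the warning through `diffusers.utils.deprecate`, which at this point in the library takes a label for the deprecated behavior, the version in which support is dropped, and the user-facing message (and errors once that version is reached). A sketch of the call, with a placeholder message:

    from diffusers.utils import deprecate

    message = 'Please use `torch.Generator(device=...)` instead.'  # placeholder text
    # Emits a FutureWarning with `message` until version 0.11.0 is reached.
    deprecate("generator.device == 'cpu'", "0.11.0", message)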
@@ -292,10 +292,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # 6. Add noise
         variance = 0
         if t > 0:
-            noise = torch.randn(
-                model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
-            ).to(model_output.device)
-            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+            device = model_output.device
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+                variance_noise = variance_noise.to(device)
+            else:
+                variance_noise = torch.randn(
+                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+                )
+            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

         pred_prev_sample = pred_prev_sample + variance
...
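The scheduler branch works around `torch.randn` not sampling reproducibly when asked to generate directly on mps: the noise is drawn on the CPU, where the seeded generator lives, and then moved. The same workaround as a standalone helper (the function name is ours, not the library's):

    import torch

    def reproducible_randn_like(t: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
        if t.device.type == "mps":
            # Draw on CPU first; randn is not reproducible directly on mps.
            return torch.randn(t.shape, dtype=t.dtype, generator=generator).to(t.device)
        # Assumes `generator` lives on t.device in this branch.
        return torch.randn(t.shape, dtype=t.dtype, device=t.device, generator=generator)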
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -43,21 +43,18 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()

         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -67,13 +64,12 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

 @slow
-@require_torch
+@require_torch_gpu
 class DDIMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_ema_bedroom(self):
         model_id = "google/ddpm-ema-bedroom-256"
@@ -85,13 +81,13 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        expected_slice = np.array([0.1546, 0.1561, 0.1595, 0.1564, 0.1569, 0.1585, 0.1554, 0.1550, 0.1575])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_cifar10(self):
@@ -104,11 +100,11 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddim(generator=generator, eta=0.0, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        expected_slice = np.array([0.2060, 0.2042, 0.2022, 0.2193, 0.2146, 0.2110, 0.2471, 0.2446, 0.2388])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -20,7 +20,7 @@ import torch
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -44,21 +44,18 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()

         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -68,9 +65,8 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -85,10 +81,10 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -100,7 +96,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 @slow
-@require_torch
+@require_torch_gpu
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -112,11 +108,11 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -68,30 +68,25 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase
         return model

     def test_inference_superresolution(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
         vqvae = self.dummy_vq_model

         ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
-        ldm.to(torch_device)
+        ldm.to(device)
         ldm.set_progress_bar_config(disable=None)

-        init_image = self.dummy_image.to(torch_device)
+        init_image = self.dummy_image.to(device)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.8634, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
+        expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

 @slow
...
@@ -42,7 +42,6 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
-from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -93,11 +92,17 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
+        pipe = pipe.to(torch_device)

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
-        generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+        pipe_2 = pipe_2.to(torch_device)
+        generator_2 = generator.manual_seed(0)
         out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
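`generator_2 = generator.manual_seed(0)` works because `torch.Generator.manual_seed` reseeds the generator in place and returns it, so the second pipeline replays exactly the same random stream as the first. A self-contained illustration:

    import torch

    generator = torch.Generator(device="cpu").manual_seed(0)
    a = torch.randn(2, generator=generator)

    # manual_seed returns `self`, so this rewinds the same generator object.
    generator_2 = generator.manual_seed(0)
    b = torch.randn(2, generator=generator_2)

    assert torch.equal(a, b)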
@@ -107,13 +112,19 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
+        pipe = pipe.to(torch_device)

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+            pipe_2 = pipe_2.to(torch_device)
+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
@@ -121,13 +132,19 @@ class DownloadTests(unittest.TestCase):
     def test_load_no_safety_checker_default_locally(self):
         prompt = "hello"
         pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
+        pipe = pipe.to(torch_device)

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+            pipe_2 = pipe_2.to(torch_device)
+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
@@ -431,7 +448,7 @@ class PipelineSlowTests(unittest.TestCase):
         new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
         new_ddpm.to(torch_device)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -452,7 +469,7 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -475,7 +492,7 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -491,7 +508,7 @@ class PipelineSlowTests(unittest.TestCase):
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         images = pipe(generator=generator, output_type="numpy").images

         assert images.shape == (1, 32, 32, 3)
         assert isinstance(images, np.ndarray)
@@ -506,40 +523,8 @@ class PipelineSlowTests(unittest.TestCase):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)

-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality(self, seed):
-        model_id = "google/ddpm-cifar10-32"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler()
-        ddim_scheduler = DDIMScheduler()
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
-        ddpm.to(torch_device)
-        ddpm.set_progress_bar_config(disable=None)
-
-        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
-        ddim.to(torch_device)
-        ddim.set_progress_bar_config(disable=None)
-
-        generator = torch.manual_seed(seed)
-        ddpm_image = ddpm(generator=generator, output_type="numpy").images
-
-        generator = torch.manual_seed(seed)
-        ddim_image = ddim(
-            generator=generator,
-            num_inference_steps=1000,
-            eta=1.0,
-            output_type="numpy",
-            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
-        ).images
-
-        # the values aren't exactly equal, but the images look the same visually
-        assert np.abs(ddpm_image - ddim_image).max() < 1e-1
-
-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality_batched(self, seed):
+    def test_ddpm_ddim_equality_batched(self):
+        seed = 0
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
@@ -554,12 +539,12 @@ class PipelineSlowTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(seed)
-        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
+        ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

-        generator = torch.manual_seed(seed)
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4,
+            batch_size=2,
             generator=generator,
             num_inference_steps=1000,
             eta=1.0,
...