"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d6c63bb956358f1990443a849ca250419a238b95"
Unverified commit 7d0c2729, authored by Anton Lozhkov, committed by GitHub

Match the generator device to the pipeline for DDPM and DDIM (#1222)

* Match the generator device to the pipeline for DDPM and DDIM

* style

* fix

* update values

* fix fast tests

* trigger slow tests

* deprecate

* last value fixes

* mps fixes
Parent: 3d98dc76
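The whole change comes down to one PyTorch rule: a `torch.Generator` can only seed draws on the device it was created on. A minimal, diffusers-free sketch of the pattern this commit moves everything to (names here are illustrative):

```python
import torch

# pick whatever device the pipeline would run on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# create the generator on that device; manual_seed returns the generator,
# so construction and seeding chain into one expression
generator = torch.Generator(device=device).manual_seed(0)

# sampling directly on `device` with that generator is reproducible
a = torch.randn((2, 3), generator=generator, device=device)
generator.manual_seed(0)  # reset the state to replay the same draw
b = torch.randn((2, 3), generator=generator, device=device)
assert torch.equal(a, b)
```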
@@ -11,10 +11,12 @@ import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
+from diffusers.utils import deprecate
 from huggingface_hub import HfFolder, Repository, whoami
+from packaging import version
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -28,6 +30,7 @@ from tqdm.auto import tqdm
 logger = get_logger(__name__)

+diffusers_version = version.parse(version.parse(__version__).base_version)

 def _extract_into_tensor(arr, timesteps, broadcast_shape):
@@ -406,7 +409,11 @@ def main(args):
                scheduler=noise_scheduler,
            )

-            generator = torch.manual_seed(0)
+            deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
+            if diffusers_version < version.parse("0.8.0"):
+                generator = torch.manual_seed(0)
+            else:
+                generator = torch.Generator(device=pipeline.device).manual_seed(0)
             # run pipeline in inference (sample random noise and denoise)
             images = pipeline(
                 generator=generator,
...
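The training-script hunk above gates the new behavior on the installed diffusers version. A hedged sketch of that gate in isolation; `pipeline_device` stands in for the script's `pipeline.device`:

```python
import torch
from packaging import version
from diffusers import __version__

# base_version strips pre-release suffixes, e.g. "0.8.0.dev0" -> "0.8.0"
diffusers_version = version.parse(version.parse(__version__).base_version)

pipeline_device = "cpu"  # placeholder for pipeline.device

if diffusers_version < version.parse("0.8.0"):
    # old releases expected the global CPU generator
    generator = torch.manual_seed(0)
else:
    # newer releases accept a generator on the pipeline device
    generator = torch.Generator(device=pipeline_device).manual_seed(0)
```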
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import inspect
 from typing import Optional, Tuple, Union

 import torch

 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate
 class DDIMPipeline(DiffusionPipeline):
@@ -75,24 +75,29 @@ class DDIMPipeline(DiffusionPipeline):
             generated images.
         """

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)

-        # Ignore use_clipped_model_output if the scheduler doesn't accept this argument
-        accepts_use_clipped_model_output = "use_clipped_model_output" in set(
-            inspect.signature(self.scheduler.step).parameters.keys()
-        )
-        extra_kwargs = {}
-        if accepts_use_clipped_model_output:
-            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
-
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t).sample
@@ -100,7 +105,9 @@ class DDIMPipeline(DiffusionPipeline):
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
+            image = self.scheduler.step(
+                model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
+            ).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
...
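The check added to `DDIMPipeline.__call__` can be read as a small standalone function. This sketch mirrors its logic outside diffusers; `reconcile_generator` is our name, not library API:

```python
import torch

def reconcile_generator(generator, pipeline_device):
    """Drop a generator whose device type does not match the pipeline (mps excepted)."""
    device = torch.device(pipeline_device)
    if generator is not None and generator.device.type != device.type and device.type != "mps":
        # the pipeline warns via deprecate(...) at this point; we just fall back to None
        return None
    return generator

cpu_gen = torch.Generator(device="cpu").manual_seed(0)
assert reconcile_generator(cpu_gen, "cuda") is None    # mismatch: dropped
assert reconcile_generator(cpu_gen, "cpu") is cpu_gen  # match: kept
```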
@@ -80,12 +80,25 @@ class DDPMPipeline(DiffusionPipeline):
             new_config["predict_epsilon"] = predict_epsilon
             self.scheduler._internal_dict = FrozenDict(new_config)

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
...
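Both pipelines also swap `torch.randn(...).to(self.device)` for `torch.randn(..., device=self.device)`. The difference matters because PyTorch refuses to pair a generator with a draw on another device, as this sketch shows (the CUDA branch only runs where a GPU is present):

```python
import torch

shape = (1, 3, 8, 8)
cpu_gen = torch.Generator(device="cpu").manual_seed(0)

# old pattern: sample on CPU with a CPU generator, then move the tensor
image = torch.randn(shape, generator=cpu_gen).to("cpu")

if torch.cuda.is_available():
    # new pattern: a matching CUDA generator samples directly on the GPU
    cuda_gen = torch.Generator(device="cuda").manual_seed(0)
    image = torch.randn(shape, generator=cuda_gen, device="cuda")

    # mixing devices fails fast with a RuntimeError
    try:
        torch.randn(shape, generator=cpu_gen, device="cuda")
    except RuntimeError as err:
        print("device mismatch:", err)
```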
@@ -292,10 +292,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # 6. Add noise
         variance = 0
         if t > 0:
-            noise = torch.randn(
-                model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
-            ).to(model_output.device)
-            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+            device = model_output.device
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+                variance_noise = variance_noise.to(device)
+            else:
+                variance_noise = torch.randn(
+                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+                )
+            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

         pred_prev_sample = pred_prev_sample + variance
...
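The scheduler keeps a special case for Apple Silicon: at the time of this commit, `torch.randn` was not reproducible when sampling directly on `mps`, so noise is drawn on CPU and moved over. The same branching can be wrapped in a helper; the function name is ours, not part of diffusers:

```python
from typing import Optional

import torch

def randn_matching(ref: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw noise shaped and typed like `ref`, working around mps reproducibility."""
    device = ref.device
    if device.type == "mps":
        # sample on CPU for reproducibility, then move to mps
        noise = torch.randn(ref.shape, dtype=ref.dtype, generator=generator)
        return noise.to(device)
    return torch.randn(ref.shape, generator=generator, device=device, dtype=ref.dtype)

sample = torch.zeros(2, 4)
noise = randn_matching(sample, torch.Generator().manual_seed(0))
```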
@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -43,21 +43,18 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()

         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -67,13 +64,12 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2


 @slow
-@require_torch
+@require_torch_gpu
 class DDIMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_ema_bedroom(self):
         model_id = "google/ddpm-ema-bedroom-256"
@@ -85,13 +81,13 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        expected_slice = np.array([0.1546, 0.1561, 0.1595, 0.1564, 0.1569, 0.1585, 0.1554, 0.1550, 0.1575])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_cifar10(self):
@@ -104,11 +100,11 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddim(generator=generator, eta=0.0, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        expected_slice = np.array([0.2060, 0.2042, 0.2022, 0.2193, 0.2146, 0.2110, 0.2471, 0.2446, 0.2388])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
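The rewritten fast tests pin both the pipeline and the generator to `"cpu"` so the hard-coded `expected_slice` values are identical on every machine. The comparison idiom in isolation (a random draw stands in for pipeline output):

```python
import numpy as np
import torch

device = "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image = torch.rand((1, 32, 32, 3), generator=generator).numpy()

image_slice = image[0, -3:, -3:, -1]  # bottom-right 3x3 patch of the last channel
expected_slice = image_slice.copy()   # a real test hard-codes these nine values
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
```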
@@ -20,7 +20,7 @@ import torch
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin
@@ -44,21 +44,18 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()

         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -68,9 +65,8 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -85,10 +81,10 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]

         image_slice = image[0, -3:, -3:, -1]
@@ -100,7 +96,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

 @slow
-@require_torch
+@require_torch_gpu
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -112,11 +108,11 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -68,30 +68,25 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference_superresolution(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
         vqvae = self.dummy_vq_model

         ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
-        ldm.to(torch_device)
+        ldm.to(device)
         ldm.set_progress_bar_config(disable=None)

-        init_image = self.dummy_image.to(torch_device)
+        init_image = self.dummy_image.to(device)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.8634, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


 @slow
...
@@ -42,7 +42,6 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
-from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -93,11 +92,17 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
-        generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+        pipe_2 = pipe_2.to(torch_device)
+        generator_2 = generator.manual_seed(0)
         out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
@@ -107,13 +112,19 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+            pipe_2 = pipe_2.to(torch_device)
+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
@@ -121,13 +132,19 @@ class DownloadTests(unittest.TestCase):
     def test_load_no_safety_checker_default_locally(self):
         prompt = "hello"
         pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
+            pipe_2 = pipe_2.to(torch_device)
+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
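Note the new idiom `generator_2 = generator.manual_seed(0)`: `manual_seed` resets the state and returns the same `Generator` object, so the second pipeline replays exactly the RNG stream of the first. In isolation:

```python
import torch

generator = torch.Generator(device="cpu").manual_seed(0)
out = torch.randn(4, generator=generator)

generator_2 = generator.manual_seed(0)  # same object, state reset
assert generator_2 is generator
out_2 = torch.randn(4, generator=generator_2)
assert torch.equal(out, out_2)
```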
@@ -431,7 +448,7 @@ class PipelineSlowTests(unittest.TestCase):
             new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
             new_ddpm.to(torch_device)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -452,7 +469,7 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -475,7 +492,7 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)
@@ -491,7 +508,7 @@ class PipelineSlowTests(unittest.TestCase):
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         images = pipe(generator=generator, output_type="numpy").images

         assert images.shape == (1, 32, 32, 3)
         assert isinstance(images, np.ndarray)
@@ -506,40 +523,8 @@ class PipelineSlowTests(unittest.TestCase):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)

-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality(self, seed):
-        model_id = "google/ddpm-cifar10-32"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler()
-        ddim_scheduler = DDIMScheduler()
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
-        ddpm.to(torch_device)
-        ddpm.set_progress_bar_config(disable=None)
-
-        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
-        ddim.to(torch_device)
-        ddim.set_progress_bar_config(disable=None)
-
-        generator = torch.manual_seed(seed)
-        ddpm_image = ddpm(generator=generator, output_type="numpy").images
-
-        generator = torch.manual_seed(seed)
-        ddim_image = ddim(
-            generator=generator,
-            num_inference_steps=1000,
-            eta=1.0,
-            output_type="numpy",
-            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
-        ).images
-
-        # the values aren't exactly equal, but the images look the same visually
-        assert np.abs(ddpm_image - ddim_image).max() < 1e-1
-
-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality_batched(self, seed):
+    def test_ddpm_ddim_equality_batched(self):
+        seed = 0
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
@@ -554,12 +539,12 @@ class PipelineSlowTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(seed)
-        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
+        ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

-        generator = torch.manual_seed(seed)
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4,
+            batch_size=2,
             generator=generator,
             num_inference_steps=1000,
             eta=1.0,
...
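For reference, the surviving batched test exercises the classical equivalence: DDIM with `eta=1.0`, all 1000 steps, and clipped model output should reproduce DDPM sampling from the same seed. A hedged, self-contained version (slow: it downloads google/ddpm-cifar10-32 and assumes a CUDA device):

```python
import numpy as np
import torch
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel

device = "cuda"
unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
ddpm = DDPMPipeline(unet=unet, scheduler=DDPMScheduler()).to(device)
ddim = DDIMPipeline(unet=unet, scheduler=DDIMScheduler()).to(device)

generator = torch.Generator(device=device).manual_seed(0)
ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

generator = torch.Generator(device=device).manual_seed(0)
ddim_images = ddim(
    batch_size=2,
    generator=generator,
    num_inference_steps=1000,
    eta=1.0,
    output_type="numpy",
    use_clipped_model_output=True,  # needed to make DDIM match DDPM
).images

# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
```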