Unverified commit 8ed08e42, authored by Patrick von Platen, committed by GitHub

[Deterministic torch randn] Allow tensors to be generated on CPU (#1902)



* [Deterministic torch randn] Allow tensors to be generated on CPU

* fix more

* up

* fix more

* up

* Update src/diffusers/utils/torch_utils.py
Co-authored-by: Anton Lozhkov <anton@huggingface.co>

* Apply suggestions from code review

* up

* up

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent 0df83c79
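
In short, this commit replaces the per-pipeline mps/cuda branches with a single torch_randn helper that creates the tensor on the CPU whenever a CPU generator is supplied and then moves it to the target device. A minimal sketch of the PyTorch behavior the helper relies on (shape and device names are illustrative):

import torch

# A CPU generator forces sampling on the CPU; moving the tensor afterwards
# preserves the values, so the result is identical on every accelerator.
generator = torch.Generator(device="cpu").manual_seed(0)
noise = torch.randn((2, 3), generator=generator, device="cpu")
noise = noise.to("cuda" if torch.cuda.is_available() else "cpu")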
@@ -336,7 +336,7 @@ class TextualInversionDataset(Dataset):
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
......
@@ -381,7 +381,7 @@ class TextualInversionDataset(Dataset):
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
......
@@ -306,7 +306,7 @@ class TextualInversionDataset(Dataset):
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
......
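
Both forms of the assignment in the hunks above unpack identically; the parenthesized tuple target is simply what the project's style tooling expects. A quick equivalence check (shape values are illustrative):

img_shape = (512, 768, 3)
h, w = img_shape[0], img_shape[1]
(h2, w2,) = (img_shape[0], img_shape[1])
assert (h, w) == (h2, w2) == (512, 768)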
@@ -564,6 +564,7 @@ def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model,
 # unet utils
+
 # <original>.time_embed -> <diffusers>.time_embedding
 def unet_time_embeddings(checkpoint, original_unet_prefix):
     diffusers_checkpoint = {}
......
@@ -37,6 +37,7 @@ def rename_key(key):
 # PyTorch => Flax #
 #####################
+
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
 def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
......
@@ -24,7 +24,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel
@@ -105,11 +105,7 @@ class UnCLIPPipeline(DiffusionPipeline):
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
......
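
For callers, nothing changes except determinism: the helper is re-exported from diffusers.utils (see the __init__.py hunk below). A hedged usage sketch, with shape and seed chosen for illustration:

import torch
from diffusers.utils import torch_randn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A CPU generator makes the latents identical regardless of the target device.
generator = torch.Generator(device="cpu").manual_seed(42)
latents = torch_randn((2, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)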
@@ -29,7 +29,7 @@ from transformers import (
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel
@@ -113,11 +113,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
......
@@ -20,7 +20,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, torch_randn
 from .scheduling_utils import SchedulerMixin
@@ -273,15 +273,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
         # 6. Add noise
         variance = 0
         if t > 0:
-            device = model_output.device
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-                variance_noise = variance_noise.to(device)
-            else:
-                variance_noise = torch.randn(
-                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                )
+            variance_noise = torch_randn(
+                model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
+            )
             variance = self._get_variance(
                 t,
......
@@ -64,6 +64,7 @@ from .import_utils import (
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
+from .torch_utils import torch_randn

 if is_torch_available():
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch utilities: Utilities related to PyTorch
"""
from typing import List, Optional, Tuple, Union
from . import logging
from .import_utils import is_torch_available
if is_torch_available():
import torch
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def torch_randn(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
):
    """Helper function for creating random tensors on the desired `device` with the desired `dtype`. When a list
    of generators is passed, each batch entry can be seeded individually. If CPU generators are passed, the
    tensor is always created on the CPU and then moved to `device`.
    """
    # The device on which the tensor is created defaults to the requested device.
    rand_device = device
    batch_size = shape[0]

    # Guard against `device=None` so the `.type` accesses below are safe.
    device = device or torch.device("cpu")

    if generator is not None:
        # A list of generators has no `.device` attribute; inspect the first entry instead.
        gen_device_type = generator[0].device.type if isinstance(generator, list) else generator.device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    if isinstance(generator, list):
        # Draw each batch entry with its own generator and concatenate along the batch dimension.
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)

    return latents
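
The list-of-generators path above enables per-sample seeding, which keeps each sample's noise stable even if the batch size changes. A minimal sketch (shapes and seeds are illustrative):

import torch
from diffusers.utils import torch_randn

# One CPU generator per batch entry; each sample gets an independent stream.
generators = [torch.Generator(device="cpu").manual_seed(i) for i in range(2)]
batch = torch_randn((2, 4, 8, 8), generator=generators, device=torch.device("cpu"), dtype=torch.float32)

# The first entry matches a single draw from a generator seeded with 0.
single = torch_randn((1, 4, 8, 8), generator=torch.Generator(device="cpu").manual_seed(0), device=torch.device("cpu"), dtype=torch.float32)
assert torch.equal(batch[:1], single)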
@@ -382,7 +382,7 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
         pipeline = pipeline.to(torch_device)
         pipeline.set_progress_bar_config(disable=None)

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             "horse",
             num_images_per_prompt=1,
......
@@ -480,7 +480,7 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
         pipeline.set_progress_bar_config(disable=None)
         pipeline.enable_sequential_cpu_offload()

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             input_image,
             num_images_per_prompt=1,
......
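
Switching the tests to CPU generators makes the expected output slices hardware-independent: a CUDA generator's stream depends on the GPU implementation, while a seeded CPU generator always yields the same values. The property the tests now rely on, in brief:

import torch

# Identically seeded CPU generators produce identical draws, regardless of
# which device the resulting tensors are later moved to.
a = torch.randn((4,), generator=torch.Generator(device="cpu").manual_seed(0))
b = torch.randn((4,), generator=torch.Generator(device="cpu").manual_seed(0))
assert torch.equal(a, b)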
@@ -96,6 +96,7 @@ def ignore_underscore(key):
 def sort_objects(objects, key=None):
     "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
+
     # If no key is provided, we use a noop.
     def noop(x):
         return x
@@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
     """
     Return the same `import_statement` but with objects properly sorted.
     """
+
     # This inner function sort imports between [ ].
     def _replace(match):
         imports = match.groups()[0]
......