Unverified Commit 8ed08e42 authored by Patrick von Platen, committed by GitHub

[Deterministic torch randn] Allow tensors to be generated on CPU (#1902)



* [Deterministic torch randn] Allow tensors to be generated on CPU

* fix more

* up

* fix more

* up

* Update src/diffusers/utils/torch_utils.py
Co-authored-by: Anton Lozhkov <anton@huggingface.co>

* Apply suggestions from code review

* up

* up

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 0df83c79
@@ -336,7 +336,7 @@ class TextualInversionDataset(Dataset):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
@@ -381,7 +381,7 @@ class TextualInversionDataset(Dataset):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
@@ -306,7 +306,7 @@ class TextualInversionDataset(Dataset):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
@@ -564,6 +564,7 @@ def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model,
 
 # unet utils
 
+
 # <original>.time_embed -> <diffusers>.time_embedding
 def unet_time_embeddings(checkpoint, original_unet_prefix):
     diffusers_checkpoint = {}
@@ -37,6 +37,7 @@ def rename_key(key):
 #####################
 # PyTorch => Flax #
 #####################
+
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
 def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
@@ -24,7 +24,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel

@@ -105,11 +105,7 @@ class UnCLIPPipeline(DiffusionPipeline):
 
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
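With the `mps` branch gone, device handling lives in one place: `torch_randn` draws on the generator's device and moves the result to the requested one, so a CPU-seeded run reproduces across backends. A minimal sketch of that contract (the shape and devices are illustrative, not from the PR):

```python
import torch

from diffusers.utils import torch_randn

shape = (1, 4, 64, 64)  # illustrative latent shape
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The same CPU seed yields the same latents whether the target is cpu, cuda, or mps.
gen = torch.Generator(device="cpu").manual_seed(0)
latents_target = torch_randn(shape, generator=gen, device=target, dtype=torch.float32)

gen = torch.Generator(device="cpu").manual_seed(0)
latents_cpu = torch_randn(shape, generator=gen, device=torch.device("cpu"), dtype=torch.float32)

assert torch.allclose(latents_target.cpu(), latents_cpu)
```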
@@ -29,7 +29,7 @@ from transformers import (
 
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel

@@ -113,11 +113,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
@@ -20,7 +20,7 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, torch_randn
 from .scheduling_utils import SchedulerMixin
 
 

@@ -273,15 +273,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
 
         # 6. Add noise
         variance = 0
         if t > 0:
-            device = model_output.device
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-                variance_noise = variance_noise.to(device)
-            else:
-                variance_noise = torch.randn(
-                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                )
+            variance_noise = torch_randn(
+                model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
+            )
             variance = self._get_variance(
                 t,
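The scheduler change follows the same pattern: the per-device branching around the variance noise collapses into a single `torch_randn` call. A small sketch of the equivalent draw (the zeros tensor is a stand-in for a real UNet prediction):

```python
import torch

from diffusers.utils import torch_randn

model_output = torch.zeros(2, 3, 64, 64)  # stand-in for the UNet prediction
generator = torch.Generator(device="cpu").manual_seed(42)

# The same call the scheduler now makes inside step(); the noise is
# drawn on the generator's device and moved to model_output.device.
variance_noise = torch_randn(
    model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
)
assert variance_noise.shape == model_output.shape
```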
@@ -64,6 +64,7 @@ from .import_utils import (
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
+from .torch_utils import torch_randn
 
 
 if is_torch_available():

src/diffusers/utils/torch_utils.py (new file):
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PyTorch utilities: Utilities related to PyTorch
"""
from typing import List, Optional, Tuple, Union

from . import logging
from .import_utils import is_torch_available


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def torch_randn(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When passing
    a list of generators, one can seed each batch element individually. If CPU generators are passed, the tensor
    is always created on CPU and then moved to `device`.
    """
    # the device on which the tensor is created defaults to the requested device
    rand_device = device
    batch_size = shape[0]

    if generator is not None:
        # a list of generators is expected to live on a single device; check the first one
        gen_device = generator[0].device if isinstance(generator, list) else generator.device
        if gen_device != device and gen_device.type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device.type != device.type and gen_device.type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device.type}.")

    if isinstance(generator, list):
        # sample each batch element from its own generator, then concatenate along the batch dimension
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)

    return latents
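As the docstring notes, passing a list of generators seeds each batch element individually: the helper draws each sample with a leading dimension of 1 and concatenates along dim 0. A quick check of that path (shapes and seeds here are assumptions for illustration, not fixtures from the PR):

```python
import torch

from diffusers.utils import torch_randn

generators = [torch.Generator(device="cpu").manual_seed(i) for i in range(2)]

# Each batch element comes from its own generator.
batched = torch_randn((2, 4, 8, 8), generator=generators, device=torch.device("cpu"), dtype=torch.float32)

# Element 0 matches a single-sample draw with the same seed.
single_gen = torch.Generator(device="cpu").manual_seed(0)
single = torch_randn((1, 4, 8, 8), generator=single_gen, device=torch.device("cpu"), dtype=torch.float32)
assert torch.allclose(batched[:1], single)
```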
@@ -382,7 +382,7 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
         pipeline = pipeline.to(torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             "horse",
             num_images_per_prompt=1,
@@ -480,7 +480,7 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
         pipeline.set_progress_bar_config(disable=None)
         pipeline.enable_sequential_cpu_offload()
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             input_image,
             num_images_per_prompt=1,
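The integration tests switch to CPU generators for the same reason: the expected outputs no longer depend on `torch_device`, because two CPU generators with the same seed always agree, regardless of where the model runs:

```python
import torch

# The guarantee the tests now rely on: same CPU seed, same noise, on any backend.
g1 = torch.Generator(device="cpu").manual_seed(0)
g2 = torch.Generator(device="cpu").manual_seed(0)
assert torch.equal(torch.randn(2, 2, generator=g1), torch.randn(2, 2, generator=g2))
```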
@@ -96,6 +96,7 @@ def ignore_underscore(key):
 
 def sort_objects(objects, key=None):
     "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
+
     # If no key is provided, we use a noop.
     def noop(x):
         return x

@@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
     """
     Return the same `import_statement` but with objects properly sorted.
     """
+
     # This inner function sort imports between [ ].
     def _replace(match):
         imports = match.groups()[0]