Unverified Commit f427345a authored by Arsalan, committed by GitHub

Device agnostic testing (#5612)

* utils and test modifications to enable device agnostic testing

* device for manual seed in unet1d

* fix generator condition in vae test

* consistency changes to testing

* make style

* add device agnostic testing changes to source and one model test

* make dtype check fns private, log cuda fp16 case

* remove dtype checks from import utils, move to testing_utils

* adding tests for most model classes and one pipeline

* fix vae import
parent 6e221334
......@@ -17,7 +17,7 @@ from contextlib import contextmanager
from distutils.util import strtobool
from io import BytesIO, StringIO
from pathlib import Path
from typing import List, Optional, Union
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
......@@ -58,6 +58,17 @@ USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
if is_torch_available():
import torch
# Set a backend environment variable for any extra module import required for a custom accelerator
if "DIFFUSERS_TEST_BACKEND" in os.environ:
backend = os.environ["DIFFUSERS_TEST_BACKEND"]
try:
_ = importlib.import_module(backend)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
to enable a specified backend.):\n{e}"
) from e
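# Illustrative sketch only (the module and device names below are hypothetical): a run against a custom
# accelerator would set something like
#   DIFFUSERS_TEST_BACKEND="torch_npu" DIFFUSERS_TEST_DEVICE="npu" python -m pytest tests/models
# The module named in DIFFUSERS_TEST_BACKEND is imported above purely for its side effects, i.e. to
# register the custom device with PyTorch before `torch_device` is resolved below.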
if "DIFFUSERS_TEST_DEVICE" in os.environ:
torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
try:
......@@ -210,6 +221,36 @@ def require_torch_gpu(test_case):
)
# These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
test_case
)
def require_torch_accelerator_with_fp16(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
test_case
)
def require_torch_accelerator_with_fp64(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP64 data type."""
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
test_case
)
def require_torch_accelerator_with_training(test_case):
"""Decorator marking a test that requires an accelerator with support for training."""
return unittest.skipUnless(
is_torch_available() and backend_supports_training(torch_device),
"test requires accelerator with training support",
)(test_case)
def skip_mps(test_case):
"""Decorator marking a test to skip if torch_device is 'mps'"""
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
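# A minimal usage sketch (hypothetical test class and model; not part of this file): the decorators above
# replace hard-coded `torch_device == "cuda"` checks, so the same test is collected on CUDA, MPS, or any
# backend registered via DIFFUSERS_TEST_DEVICE.
#
#   class ExampleModelTests(unittest.TestCase):
#       @require_torch_accelerator_with_fp16
#       def test_forward_fp16(self):
#           model = MyTinyModel().to(torch_device, dtype=torch.float16)  # MyTinyModel is a placeholder
#           _ = model(torch.randn(1, 4, 8, 8, device=torch_device, dtype=torch.float16))
#
#       @require_torch_accelerator_with_training
#       def test_backward(self):
#           ...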
......@@ -766,3 +807,139 @@ def disable_full_determinism():
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
torch.use_deterministic_algorithms(False)
# Utils for custom and alternative accelerator devices
def _is_torch_fp16_available(device):
if not is_torch_available():
return False
import torch
device = torch.device(device)
try:
x = torch.zeros((2, 2), dtype=torch.float16).to(device)
_ = x @ x
return True
except Exception as e:
if device.type == "cuda":
raise ValueError(
f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
)
return False
def _is_torch_fp64_available(device):
if not is_torch_available():
return False
import torch
device = torch.device(device)
try:
x = torch.zeros((2, 2), dtype=torch.float64).to(device)
_ = x @ x
return True
except Exception as e:
if device.type == "cuda":
raise ValueError(
f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
)
return False
# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
if is_torch_available():
# Behaviour flags
BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
# Function definitions
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": lambda: 0}
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
# This dispatches the function registered for the given accelerator in the dispatch table above.
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
if device not in dispatch_table:
return dispatch_table["default"](*args, **kwargs)
fn = dispatch_table[device]
# Some devices map a function to `None` (a no-op for that backend); return `None` here and let
# the caller guard against it at user level
if fn is None:
return None
return fn(*args, **kwargs)
# These are callables which automatically dispatch the function specific to the accelerator
def backend_manual_seed(device: str, seed: int):
return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
def backend_empty_cache(device: str):
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
def backend_device_count(device: str):
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
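# For illustration (assumes torch is importable): these wrappers replace direct `torch.cuda.*` calls in
# the tests, e.g.
#   backend_manual_seed(torch_device, 0)   # torch.cuda.manual_seed on CUDA, torch.manual_seed otherwise
#   backend_empty_cache(torch_device)      # torch.cuda.empty_cache on CUDA, a no-op on CPU/MPS
#   backend_device_count(torch_device)     # torch.cuda.device_count on CUDA, 0 on CPU/MPS
# Unknown devices fall through to the "default" entry of each table, which a device spec file can override.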
# These are callables which return boolean behaviour flags and can be used to specify some
# device agnostic alternative where the feature is unsupported.
def backend_supports_training(device: str):
if not is_torch_available():
return False
if device not in BACKEND_SUPPORTS_TRAINING:
device = "default"
return BACKEND_SUPPORTS_TRAINING[device]
# Guard for when Torch is not available
if is_torch_available():
# Update device function dict mapping
def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
try:
# Try to fetch the attribute from the device spec module
spec_fn = getattr(device_spec_module, attribute_name)
device_fn_dict[torch_device] = spec_fn
except AttributeError as e:
# If the function doesn't exist, and there is no default, throw an error
if "default" not in device_fn_dict:
raise AttributeError(
f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
) from e
if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
if not Path(device_spec_path).is_file():
raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
try:
import_name = device_spec_path[: device_spec_path.index(".py")]
except ValueError as e:
raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
device_spec_module = importlib.import_module(import_name)
try:
device_name = device_spec_module.DEVICE_NAME
except AttributeError:
raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
raise ValueError(msg)
torch_device = device_name
# Add one entry here for each `BACKEND_*` dictionary.
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
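# A hedged sketch of what a device spec file could look like (file and device names are hypothetical; the
# required attribute names come from the lookups above). The file must be importable from the working
# directory and is selected via DIFFUSERS_TEST_DEVICE_SPEC, e.g. DIFFUSERS_TEST_DEVICE_SPEC=my_npu_spec.py:
#
#   # my_npu_spec.py
#   import torch
#
#   DEVICE_NAME = "npu"                 # must match DIFFUSERS_TEST_DEVICE if that is also set
#   MANUAL_SEED_FN = torch.manual_seed  # plugged into BACKEND_MANUAL_SEED
#   EMPTY_CACHE_FN = lambda: None       # plugged into BACKEND_EMPTY_CACHE
#   DEVICE_COUNT_FN = lambda: 1         # plugged into BACKEND_DEVICE_COUNT
#   SUPPORTS_TRAINING = True            # plugged into BACKEND_SUPPORTS_TRAINING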
......@@ -25,7 +25,11 @@ from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.utils.testing_utils import torch_device
from diffusers.utils.testing_utils import (
backend_manual_seed,
require_torch_accelerator_with_fp64,
torch_device,
)
class EmbeddingsTests(unittest.TestCase):
......@@ -315,8 +319,7 @@ class ResnetBlock2DTests(unittest.TestCase):
class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_default(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
......@@ -339,8 +342,7 @@ class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_cross_attention_dim(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)
sample = torch.randn(1, 64, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
......@@ -363,8 +365,7 @@ class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_timestep(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)
num_embeds_ada_norm = 5
......@@ -401,8 +402,7 @@ class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_dropout(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = (
......@@ -427,11 +427,10 @@ class Transformer2DModelTests(unittest.TestCase):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
@unittest.skipIf(torch_device == "mps", "MPS does not support float64")
@require_torch_accelerator_with_fp64
def test_spatial_transformer_discrete(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)
num_embed = 5
......
......@@ -35,6 +35,7 @@ from diffusers.utils.testing_utils import (
CaptureLogger,
require_python39_or_higher,
require_torch_2,
require_torch_accelerator_with_training,
require_torch_gpu,
run_test_in_subprocess,
torch_device,
......@@ -536,7 +537,7 @@ class ModelTesterMixin:
self.assertEqual(output_1.shape, output_2.shape)
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
@require_torch_accelerator_with_training
def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......@@ -553,7 +554,7 @@ class ModelTesterMixin:
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
@require_torch_accelerator_with_training
def test_ema_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......@@ -624,7 +625,7 @@ class ModelTesterMixin:
recursive_check(outputs_tuple, outputs_dict)
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
@require_torch_accelerator_with_training
def test_enable_disable_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
......
......@@ -21,7 +21,14 @@ import torch
from parameterized import parameterized
from diffusers import PriorTransformer
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, slow, torch_all_close, torch_device
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
slow,
torch_all_close,
torch_device,
)
from .test_modeling_common import ModelTesterMixin
......@@ -157,7 +164,7 @@ class PriorTransformerIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
@parameterized.expand(
[
......
......@@ -18,7 +18,12 @@ import unittest
import torch
from diffusers import UNet1DModel
from diffusers.utils.testing_utils import floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import (
backend_manual_seed,
floats_tensor,
slow,
torch_device,
)
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
......@@ -103,8 +108,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def test_output_pretrained(self):
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)
num_features = model.config.in_channels
seq_len = 16
......@@ -244,8 +248,7 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)
num_features = value_function.config.in_channels
seq_len = 14
......
......@@ -24,6 +24,7 @@ from diffusers.utils import logging
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
slow,
torch_all_close,
torch_device,
......@@ -153,7 +154,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
@require_torch_accelerator
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model.to(torch_device)
......@@ -161,7 +162,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
@require_torch_accelerator
def test_from_pretrained_accelerate_wont_change_results(self):
# by default, model loading will use accelerate as `low_cpu_mem_usage=True`
model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
......
......@@ -30,10 +30,15 @@ from diffusers.models.embeddings import ImageProjection, Resampler
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_accelerator_with_training,
require_torch_gpu,
skip_mps,
slow,
torch_all_close,
torch_device,
......@@ -280,7 +285,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
== "XFormersAttnProcessor"
), "xformers is not enabled"
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
@require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......@@ -864,7 +869,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
......@@ -882,6 +887,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
return model
@require_torch_gpu
def test_set_attention_slice_auto(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
......@@ -901,6 +907,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
assert mem_bytes < 5 * 10**9
@require_torch_gpu
def test_set_attention_slice_max(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
......@@ -920,6 +927,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
assert mem_bytes < 5 * 10**9
@require_torch_gpu
def test_set_attention_slice_int(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
......@@ -939,6 +947,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
assert mem_bytes < 5 * 10**9
@require_torch_gpu
def test_set_attention_slice_list(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
......@@ -975,7 +984,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator_with_fp16
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed)
......@@ -1003,7 +1012,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator_with_fp16
def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
latents = self.get_latents(seed, fp16=True)
......@@ -1031,7 +1040,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator
@skip_mps
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed)
......@@ -1059,7 +1069,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator_with_fp16
def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
latents = self.get_latents(seed, fp16=True)
......@@ -1087,7 +1097,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator
@skip_mps
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64))
......@@ -1115,7 +1126,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator_with_fp16
def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
......@@ -1143,7 +1154,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator_with_fp16
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
......
......@@ -31,10 +31,15 @@ from diffusers import (
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_accelerator_with_training,
require_torch_gpu,
skip_mps,
slow,
torch_all_close,
torch_device,
......@@ -157,7 +162,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def test_training(self):
pass
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
@require_torch_accelerator_with_training
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......@@ -213,10 +218,12 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model = model.to(torch_device)
model.eval()
if torch_device == "mps":
generator = torch.manual_seed(0)
# Keep generator on CPU for non-CUDA devices to compare outputs with CPU result tensors
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
if torch_device != "mps":
generator = torch.Generator(device=generator_device).manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
generator = torch.manual_seed(0)
image = torch.randn(
1,
......@@ -247,7 +254,7 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
-9.8644e-03,
]
)
elif torch_device == "cpu":
elif generator_device == "cpu":
expected_output_slice = torch.tensor(
[
-0.1352,
......@@ -478,7 +485,7 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
......@@ -558,7 +565,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
......@@ -580,9 +587,10 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
return model
def get_generator(self, seed=0):
if torch_device == "mps":
return torch.manual_seed(seed)
return torch.Generator(device=torch_device).manual_seed(seed)
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
if torch_device != "mps":
return torch.Generator(device=generator_device).manual_seed(seed)
return torch.manual_seed(seed)
@parameterized.expand(
[
......@@ -623,7 +631,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator_with_fp16
def test_stable_diffusion_fp16(self, seed, expected_slice):
model = self.get_sd_vae_model(fp16=True)
image = self.get_sd_image(seed, fp16=True)
......@@ -677,7 +685,8 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator
@skip_mps
def test_stable_diffusion_decode(self, seed, expected_slice):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
......@@ -700,7 +709,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator_with_fp16
def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
......@@ -811,7 +820,7 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
......@@ -832,9 +841,10 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
return model
def get_generator(self, seed=0):
if torch_device == "mps":
return torch.manual_seed(seed)
return torch.Generator(device=torch_device).manual_seed(seed)
generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda"
if torch_device != "mps":
return torch.Generator(device=generator_device).manual_seed(seed)
return torch.manual_seed(seed)
@parameterized.expand(
[
......@@ -905,7 +915,8 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
@require_torch_accelerator
@skip_mps
def test_stable_diffusion_decode(self, seed, expected_slice):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
......
......@@ -18,7 +18,12 @@ import unittest
import torch
from diffusers import VQModel
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from diffusers.utils.testing_utils import (
backend_manual_seed,
enable_full_determinism,
floats_tensor,
torch_device,
)
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
......@@ -80,8 +85,7 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model.to(torch_device).eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
......
......@@ -12,12 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Tuple
import torch
from diffusers.utils.testing_utils import floats_tensor, require_torch, torch_all_close, torch_device
from diffusers.utils.testing_utils import (
floats_tensor,
require_torch,
require_torch_accelerator_with_training,
torch_all_close,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
......@@ -104,7 +109,7 @@ class UNetBlockTesterMixin:
expected_slice = torch.tensor(expected_slice).to(torch_device)
assert torch_all_close(output_slice.flatten(), expected_slice, atol=5e-3)
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
@require_torch_accelerator_with_training
def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.block_class(**init_dict)
......
......@@ -34,11 +34,14 @@ from diffusers import (
)
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
enable_full_determinism,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
require_torch_accelerator,
require_torch_gpu,
skip_mps,
slow,
torch_device,
)
......@@ -128,10 +131,12 @@ class StableDiffusion2PipelineFastTests(
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
generator_device = "cpu" if not device.startswith("cuda") else "cuda"
if not str(device).startswith("mps"):
generator = torch.Generator(device=generator_device).manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
generator = torch.manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
......@@ -299,15 +304,21 @@ class StableDiffusion2PipelineFastTests(
@slow
@require_torch_gpu
@require_torch_accelerator
@skip_mps
class StableDiffusion2PipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
if not str(device).startswith("mps"):
generator = torch.Generator(device=_generator_device).manual_seed(seed)
else:
generator = torch.manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
......@@ -361,6 +372,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.10440, 0.13115, 0.11100, 0.10141, 0.11440, 0.07215, 0.11332, 0.09693, 0.10006])
assert np.abs(image_slice - expected_slice).max() < 3e-3
@require_torch_gpu
def test_stable_diffusion_attention_slicing(self):
torch.cuda.reset_peak_memory_stats()
pipe = StableDiffusionPipeline.from_pretrained(
......@@ -432,6 +444,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
assert callback_fn.has_been_called
assert number_of_steps == inputs["num_inference_steps"]
@require_torch_gpu
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
......@@ -452,6 +465,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
# make sure that less than 2.8 GB is allocated
assert mem_bytes < 2.8 * 10**9
@require_torch_gpu
def test_stable_diffusion_pipeline_with_model_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
......@@ -511,15 +525,21 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
@nightly
@require_torch_gpu
@require_torch_accelerator
@skip_mps
class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
if not str(device).startswith("mps"):
generator = torch.Generator(device=_generator_device).manual_seed(seed)
else:
generator = torch.manual_seed(seed)
latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
inputs = {
......