Unverified Commit d6bf268a authored by Yao Matrix's avatar Yao Matrix Committed by GitHub
Browse files

enable dit integration cases on xpu (#11523)



* enable dit integration test on XPU
Signed-off-by: default avatarYao Matrix <matrix.yao@intel.com>

* fix style
Signed-off-by: default avatarYao Matrix <matrix.yao@intel.com>

---------
Signed-off-by: default avatarYao Matrix <matrix.yao@intel.com>
parent 3c0a0129
......@@ -21,7 +21,15 @@ import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
require_torch_accelerator,
torch_device,
)
from ..pipeline_params import (
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,
......@@ -107,23 +115,23 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@nightly
@require_torch_gpu
@require_torch_accelerator
class DiTPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_dit_256(self):
generator = torch.manual_seed(0)
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
pipe.to("cuda")
pipe.to(torch_device)
words = ["vase", "umbrella", "white shark", "white wolf"]
ids = pipe.get_label_ids(words)
......@@ -139,7 +147,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
def test_dit_512(self):
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
pipe.to(torch_device)
words = ["vase", "umbrella"]
ids = pipe.get_label_ids(words)
......@@ -152,4 +160,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy"
)
assert np.abs((expected_image - image).max()) < 1e-1
expected_slice = expected_image.flatten()
output_slice = image.flatten()
assert numpy_cosine_similarity_distance(expected_slice, output_slice) < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment