Unverified Commit 178d32de authored by Aryan, committed by GitHub

[tests] Add test slices for Wan (#11920)

* update

* fix wan vace test slice

* test

* fix
parent ef1e6287
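
Reviewer note: the old assertions compared the generated video against a fresh `torch.randn(...)` tensor with a tolerance of `1e10`, so they could never fail. The new tests instead record a deterministic 16-value slice per pipeline: flatten the output, keep the first and last 8 values, and compare with `torch.allclose` at `atol=1e-3`. A minimal sketch of the shared pattern (the helper name and sample values are illustrative, not part of this PR):

```python
import torch

def check_video_slice(generated_video: torch.Tensor, expected_slice: torch.Tensor) -> bool:
    # Flatten the (frames, channels, height, width) output and compare only
    # the first and last 8 values against the recorded reference slice.
    flat = generated_video.flatten()
    generated_slice = torch.cat([flat[:8], flat[-8:]])
    return torch.allclose(generated_slice, expected_slice, atol=1e-3)

# Usage with illustrative values:
video = torch.full((9, 3, 16, 16), 0.45)
assert check_video_slice(video, torch.full((16,), 0.45))
```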
@@ -15,7 +15,6 @@
 import gc
 import unittest
-import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
@@ -29,9 +28,7 @@ from diffusers.utils.testing_utils import (
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
-    PipelineTesterMixin,
-)
+from ..test_pipelines_common import PipelineTesterMixin

 enable_full_determinism()
@@ -127,11 +124,15 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+        # fmt: on
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
...
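
If the dummy components, scheduler defaults, or seeds change, the recorded slices have to be regenerated from a trusted local run rather than edited by hand. A hedged sketch of how one might print fresh values (not a utility shipped in this PR; `video` is assumed to be the pipeline output from the test above):

```python
# Run under enable_full_determinism() with fixed seeds, as these tests do,
# or the printed values will not reproduce across machines.
flat = video[0].flatten()
slice_values = torch.cat([flat[:8], flat[-8:]])
print([round(v, 4) for v in slice_values.tolist()])  # paste into expected_slice
```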
@@ -14,7 +14,6 @@
 import unittest
-import numpy as np
 import torch
 from PIL import Image
 from transformers import (
@@ -147,11 +146,15 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+        # fmt: on
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
@@ -162,7 +165,25 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         pass


-class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = WanImageToVideoPipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
+    test_xformers_attention = False
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         vae = AutoencoderKLWan(
@@ -247,3 +268,32 @@ class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
             "output_type": "pt",
         }
         return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        video = pipe(**inputs).frames
+        generated_video = video[0]
+        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+        # fmt: on
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+    @unittest.skip("Test not supported")
+    def test_attention_slicing_forward_pass(self):
+        pass
+
+    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+    def test_inference_batch_single_identical(self):
+        pass
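
The base-class change above is worth calling out: `WanFLFToVideoPipelineFastTests` previously inherited `test_inference` from `WanImageToVideoPipelineFastTests`, which was harmless while the assertion could never fail, but would break once the parent records real image-to-video slice values, since the first-last-frame variant produces different output. Rebasing on `PipelineTesterMixin` directly means the class parameters, skips, and an FLF-specific `test_inference` are declared locally. A self-contained illustration of the inheritance pitfall (names are hypothetical, not from the PR):

```python
import unittest

class ParentTests(unittest.TestCase):
    expected = 1.0

    def produce(self):
        return 1.0

    def test_value(self):
        # Subclasses inherit this test together with `expected`.
        self.assertEqual(self.produce(), self.expected)

class ChildTests(ParentTests):
    def produce(self):
        # Different output, but the inherited test still checks the parent's
        # expectation, so ChildTests.test_value fails unless the expectation
        # is also overridden; decoupling the classes avoids this hidden link.
        return 2.0
```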
@@ -14,7 +14,6 @@
 import unittest
-import numpy as np
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel
@@ -123,11 +122,15 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
         self.assertEqual(generated_video.shape, (17, 3, 16, 16))
-        expected_video = torch.randn(17, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+        # fmt: off
+        expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
+        # fmt:on
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
...