Unverified Commit 67a8ec8b authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[tests] Add test slices for Hunyuan Video (#11954)

update
parent cde02b06
...@@ -229,12 +229,19 @@ class HunyuanVideoImageToVideoPipelineFastTests( ...@@ -229,12 +229,19 @@ class HunyuanVideoImageToVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames video = pipe(**inputs).frames
generated_video = video[0] generated_video = video[0]
# NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline # NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline
self.assertEqual(generated_video.shape, (5, 3, 16, 16)) self.assertEqual(generated_video.shape, (5, 3, 16, 16))
expected_video = torch.randn(5, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max() # fmt: off
self.assertLessEqual(max_diff, 1e10) expected_slice = torch.tensor([0.444, 0.479, 0.4485, 0.5752, 0.3539, 0.1548, 0.2706, 0.3593, 0.5323, 0.6635, 0.6795, 0.5255, 0.5091, 0.345, 0.4276, 0.4128])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(generated_slice, expected_slice, atol=1e-3),
"The generated video does not match the expected slice.",
)
def test_callback_inputs(self): def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__) sig = inspect.signature(self.pipeline_class.__call__)
......
...@@ -192,11 +192,18 @@ class HunyuanSkyreelsImageToVideoPipelineFastTests( ...@@ -192,11 +192,18 @@ class HunyuanSkyreelsImageToVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames video = pipe(**inputs).frames
generated_video = video[0] generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16)) self.assertEqual(generated_video.shape, (9, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max() # fmt: off
self.assertLessEqual(max_diff, 1e10) expected_slice = torch.tensor([0.5832, 0.5498, 0.4839, 0.4744, 0.4515, 0.4832, 0.496, 0.563, 0.5918, 0.5979, 0.5101, 0.6168, 0.6613, 0.536, 0.55, 0.5775])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(generated_slice, expected_slice, atol=1e-3),
"The generated video does not match the expected slice.",
)
def test_callback_inputs(self): def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__) sig = inspect.signature(self.pipeline_class.__call__)
......
...@@ -26,10 +26,7 @@ from diffusers import ( ...@@ -26,10 +26,7 @@ from diffusers import (
HunyuanVideoPipeline, HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel, HunyuanVideoTransformer3DModel,
) )
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import enable_full_determinism, torch_device
enable_full_determinism,
torch_device,
)
from ..test_pipelines_common import ( from ..test_pipelines_common import (
FasterCacheTesterMixin, FasterCacheTesterMixin,
...@@ -206,11 +203,18 @@ class HunyuanVideoPipelineFastTests( ...@@ -206,11 +203,18 @@ class HunyuanVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames video = pipe(**inputs).frames
generated_video = video[0] generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16)) self.assertEqual(generated_video.shape, (9, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max() # fmt: off
self.assertLessEqual(max_diff, 1e10) expected_slice = torch.tensor([0.3946, 0.4649, 0.3196, 0.4569, 0.3312, 0.3687, 0.3216, 0.3972, 0.4469, 0.3888, 0.3929, 0.3802, 0.3479, 0.3888, 0.3825, 0.3542])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(generated_slice, expected_slice, atol=1e-3),
"The generated video does not match the expected slice.",
)
def test_callback_inputs(self): def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__) sig = inspect.signature(self.pipeline_class.__call__)
......
...@@ -227,11 +227,18 @@ class HunyuanVideoFramepackPipelineFastTests( ...@@ -227,11 +227,18 @@ class HunyuanVideoFramepackPipelineFastTests(
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames video = pipe(**inputs).frames
generated_video = video[0] generated_video = video[0]
self.assertEqual(generated_video.shape, (13, 3, 32, 32)) self.assertEqual(generated_video.shape, (13, 3, 32, 32))
expected_video = torch.randn(13, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max() # fmt: off
self.assertLessEqual(max_diff, 1e10) expected_slice = torch.tensor([0.363, 0.3384, 0.3426, 0.3512, 0.3372, 0.3276, 0.417, 0.4061, 0.5221, 0.467, 0.4813, 0.4556, 0.4107, 0.3945, 0.4049, 0.4551])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(generated_slice, expected_slice, atol=1e-3),
"The generated video does not match the expected slice.",
)
def test_callback_inputs(self): def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__) sig = inspect.signature(self.pipeline_class.__call__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment