Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
178d32de
Unverified
Commit
178d32de
authored
Jul 23, 2025
by
Aryan
Committed by
GitHub
Jul 23, 2025
Browse files
[tests] Add test slices for Wan (#11920)
* update * fix wan vace test slice * test * fix
parent
ef1e6287
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
19 deletions
+73
-19
tests/pipelines/wan/test_wan.py
tests/pipelines/wan/test_wan.py
+9
-8
tests/pipelines/wan/test_wan_image_to_video.py
tests/pipelines/wan/test_wan_image_to_video.py
+56
-6
tests/pipelines/wan/test_wan_video_to_video.py
tests/pipelines/wan/test_wan_video_to_video.py
+8
-5
No files found.
tests/pipelines/wan/test_wan.py
View file @
178d32de
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
import
gc
import
gc
import
unittest
import
unittest
import
numpy
as
np
import
torch
import
torch
from
transformers
import
AutoTokenizer
,
T5EncoderModel
from
transformers
import
AutoTokenizer
,
T5EncoderModel
...
@@ -29,9 +28,7 @@ from diffusers.utils.testing_utils import (
...
@@ -29,9 +28,7 @@ from diffusers.utils.testing_utils import (
)
)
from
..pipeline_params
import
TEXT_TO_IMAGE_BATCH_PARAMS
,
TEXT_TO_IMAGE_IMAGE_PARAMS
,
TEXT_TO_IMAGE_PARAMS
from
..pipeline_params
import
TEXT_TO_IMAGE_BATCH_PARAMS
,
TEXT_TO_IMAGE_IMAGE_PARAMS
,
TEXT_TO_IMAGE_PARAMS
from
..test_pipelines_common
import
(
from
..test_pipelines_common
import
PipelineTesterMixin
PipelineTesterMixin
,
)
enable_full_determinism
()
enable_full_determinism
()
...
@@ -127,11 +124,15 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -127,11 +124,15 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
=
self
.
get_dummy_inputs
(
device
)
video
=
pipe
(
**
inputs
).
frames
video
=
pipe
(
**
inputs
).
frames
generated_video
=
video
[
0
]
generated_video
=
video
[
0
]
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
16
,
16
))
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
16
,
16
))
expected_video
=
torch
.
randn
(
9
,
3
,
16
,
16
)
max_diff
=
np
.
abs
(
generated_video
-
expected_video
).
max
()
# fmt: off
self
.
assertLessEqual
(
max_diff
,
1e10
)
expected_slice
=
torch
.
tensor
([
0.4525
,
0.452
,
0.4485
,
0.4534
,
0.4524
,
0.4529
,
0.454
,
0.453
,
0.5127
,
0.5326
,
0.5204
,
0.5253
,
0.5439
,
0.5424
,
0.5133
,
0.5078
])
# fmt: on
generated_slice
=
generated_video
.
flatten
()
generated_slice
=
torch
.
cat
([
generated_slice
[:
8
],
generated_slice
[
-
8
:]])
self
.
assertTrue
(
torch
.
allclose
(
generated_slice
,
expected_slice
,
atol
=
1e-3
))
@
unittest
.
skip
(
"Test not supported"
)
@
unittest
.
skip
(
"Test not supported"
)
def
test_attention_slicing_forward_pass
(
self
):
def
test_attention_slicing_forward_pass
(
self
):
...
...
tests/pipelines/wan/test_wan_image_to_video.py
View file @
178d32de
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
import
unittest
import
unittest
import
numpy
as
np
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
(
from
transformers
import
(
...
@@ -147,11 +146,15 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -147,11 +146,15 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
=
self
.
get_dummy_inputs
(
device
)
video
=
pipe
(
**
inputs
).
frames
video
=
pipe
(
**
inputs
).
frames
generated_video
=
video
[
0
]
generated_video
=
video
[
0
]
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
16
,
16
))
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
16
,
16
))
expected_video
=
torch
.
randn
(
9
,
3
,
16
,
16
)
max_diff
=
np
.
abs
(
generated_video
-
expected_video
).
max
()
# fmt: off
self
.
assertLessEqual
(
max_diff
,
1e10
)
expected_slice
=
torch
.
tensor
([
0.4525
,
0.4525
,
0.4497
,
0.4536
,
0.452
,
0.4529
,
0.454
,
0.4535
,
0.5072
,
0.5527
,
0.5165
,
0.5244
,
0.5481
,
0.5282
,
0.5208
,
0.5214
])
# fmt: on
generated_slice
=
generated_video
.
flatten
()
generated_slice
=
torch
.
cat
([
generated_slice
[:
8
],
generated_slice
[
-
8
:]])
self
.
assertTrue
(
torch
.
allclose
(
generated_slice
,
expected_slice
,
atol
=
1e-3
))
@
unittest
.
skip
(
"Test not supported"
)
@
unittest
.
skip
(
"Test not supported"
)
def
test_attention_slicing_forward_pass
(
self
):
def
test_attention_slicing_forward_pass
(
self
):
...
@@ -162,7 +165,25 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -162,7 +165,25 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pass
pass
class
WanFLFToVideoPipelineFastTests
(
WanImageToVideoPipelineFastTests
):
class
WanFLFToVideoPipelineFastTests
(
PipelineTesterMixin
,
unittest
.
TestCase
):
pipeline_class
=
WanImageToVideoPipeline
params
=
TEXT_TO_IMAGE_PARAMS
-
{
"cross_attention_kwargs"
,
"height"
,
"width"
}
batch_params
=
TEXT_TO_IMAGE_BATCH_PARAMS
image_params
=
TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params
=
TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params
=
frozenset
(
[
"num_inference_steps"
,
"generator"
,
"latents"
,
"return_dict"
,
"callback_on_step_end"
,
"callback_on_step_end_tensor_inputs"
,
]
)
test_xformers_attention
=
False
supports_dduf
=
False
def
get_dummy_components
(
self
):
def
get_dummy_components
(
self
):
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
vae
=
AutoencoderKLWan
(
vae
=
AutoencoderKLWan
(
...
@@ -247,3 +268,32 @@ class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
...
@@ -247,3 +268,32 @@ class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
"output_type"
:
"pt"
,
"output_type"
:
"pt"
,
}
}
return
inputs
return
inputs
def
test_inference
(
self
):
device
=
"cpu"
components
=
self
.
get_dummy_components
()
pipe
=
self
.
pipeline_class
(
**
components
)
pipe
.
to
(
device
)
pipe
.
set_progress_bar_config
(
disable
=
None
)
inputs
=
self
.
get_dummy_inputs
(
device
)
video
=
pipe
(
**
inputs
).
frames
generated_video
=
video
[
0
]
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
16
,
16
))
# fmt: off
expected_slice
=
torch
.
tensor
([
0.4531
,
0.4527
,
0.4498
,
0.4542
,
0.4526
,
0.4527
,
0.4534
,
0.4534
,
0.5061
,
0.5185
,
0.5283
,
0.5181
,
0.5309
,
0.5365
,
0.5113
,
0.5244
])
# fmt: on
generated_slice
=
generated_video
.
flatten
()
generated_slice
=
torch
.
cat
([
generated_slice
[:
8
],
generated_slice
[
-
8
:]])
self
.
assertTrue
(
torch
.
allclose
(
generated_slice
,
expected_slice
,
atol
=
1e-3
))
@
unittest
.
skip
(
"Test not supported"
)
def
test_attention_slicing_forward_pass
(
self
):
pass
@
unittest
.
skip
(
"TODO: revisit failing as it requires a very high threshold to pass"
)
def
test_inference_batch_single_identical
(
self
):
pass
tests/pipelines/wan/test_wan_video_to_video.py
View file @
178d32de
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
import
unittest
import
unittest
import
numpy
as
np
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
AutoTokenizer
,
T5EncoderModel
from
transformers
import
AutoTokenizer
,
T5EncoderModel
...
@@ -123,11 +122,15 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
...
@@ -123,11 +122,15 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
=
self
.
get_dummy_inputs
(
device
)
video
=
pipe
(
**
inputs
).
frames
video
=
pipe
(
**
inputs
).
frames
generated_video
=
video
[
0
]
generated_video
=
video
[
0
]
self
.
assertEqual
(
generated_video
.
shape
,
(
17
,
3
,
16
,
16
))
self
.
assertEqual
(
generated_video
.
shape
,
(
17
,
3
,
16
,
16
))
expected_video
=
torch
.
randn
(
17
,
3
,
16
,
16
)
max_diff
=
np
.
abs
(
generated_video
-
expected_video
).
max
()
# fmt: off
self
.
assertLessEqual
(
max_diff
,
1e10
)
expected_slice
=
torch
.
tensor
([
0.4522
,
0.4534
,
0.4532
,
0.4553
,
0.4526
,
0.4538
,
0.4533
,
0.4547
,
0.513
,
0.5176
,
0.5286
,
0.4958
,
0.4955
,
0.5381
,
0.5154
,
0.5195
])
# fmt:on
generated_slice
=
generated_video
.
flatten
()
generated_slice
=
torch
.
cat
([
generated_slice
[:
8
],
generated_slice
[
-
8
:]])
self
.
assertTrue
(
torch
.
allclose
(
generated_slice
,
expected_slice
,
atol
=
1e-3
))
@
unittest
.
skip
(
"Test not supported"
)
@
unittest
.
skip
(
"Test not supported"
)
def
test_attention_slicing_forward_pass
(
self
):
def
test_attention_slicing_forward_pass
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment