Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d87134ad
Unverified
Commit
d87134ad
authored
Jul 21, 2025
by
Aryan
Committed by
GitHub
Jul 21, 2025
Browse files
[tests] Add test slices for Cosmos (#11955)
* test * try fix
parent
67a8ec8b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
16 deletions
+32
-16
tests/pipelines/cosmos/test_cosmos.py
tests/pipelines/cosmos/test_cosmos.py
+8
-4
tests/pipelines/cosmos/test_cosmos2_text2image.py
tests/pipelines/cosmos/test_cosmos2_text2image.py
+8
-4
tests/pipelines/cosmos/test_cosmos2_video2world.py
tests/pipelines/cosmos/test_cosmos2_video2world.py
+8
-4
tests/pipelines/cosmos/test_cosmos_video2world.py
tests/pipelines/cosmos/test_cosmos_video2world.py
+8
-4
No files found.
tests/pipelines/cosmos/test_cosmos.py
View file @
d87134ad
...
@@ -153,11 +153,15 @@ class CosmosTextToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
...
@@ -153,11 +153,15 @@ class CosmosTextToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
=
self
.
get_dummy_inputs
(
device
)
video
=
pipe
(
**
inputs
).
frames
video
=
pipe
(
**
inputs
).
frames
generated_video
=
video
[
0
]
generated_video
=
video
[
0
]
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
32
,
32
))
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
32
,
32
))
expected_video
=
torch
.
randn
(
9
,
3
,
32
,
32
)
max_diff
=
np
.
abs
(
generated_video
-
expected_video
).
max
()
# fmt: off
self
.
assertLessEqual
(
max_diff
,
1e10
)
expected_slice
=
torch
.
tensor
([
0.0
,
0.9686
,
0.8549
,
0.8078
,
0.0
,
0.8431
,
1.0
,
0.4863
,
0.7098
,
0.1098
,
0.8157
,
0.4235
,
0.6353
,
0.2549
,
0.5137
,
0.5333
])
# fmt: on
generated_slice
=
generated_video
.
flatten
()
generated_slice
=
torch
.
cat
([
generated_slice
[:
8
],
generated_slice
[
-
8
:]])
self
.
assertTrue
(
torch
.
allclose
(
generated_slice
,
expected_slice
,
atol
=
1e-3
))
def
test_callback_inputs
(
self
):
def
test_callback_inputs
(
self
):
sig
=
inspect
.
signature
(
self
.
pipeline_class
.
__call__
)
sig
=
inspect
.
signature
(
self
.
pipeline_class
.
__call__
)
...
...
tests/pipelines/cosmos/test_cosmos2_text2image.py
View file @
d87134ad
...
@@ -140,11 +140,15 @@ class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase
...
@@ -140,11 +140,15 @@ class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
=
self
.
get_dummy_inputs
(
device
)
image
=
pipe
(
**
inputs
).
images
image
=
pipe
(
**
inputs
).
images
generated_image
=
image
[
0
]
generated_image
=
image
[
0
]
self
.
assertEqual
(
generated_image
.
shape
,
(
3
,
32
,
32
))
self
.
assertEqual
(
generated_image
.
shape
,
(
3
,
32
,
32
))
expected_video
=
torch
.
randn
(
3
,
32
,
32
)
max_diff
=
np
.
abs
(
generated_image
-
expected_video
).
max
()
# fmt: off
self
.
assertLessEqual
(
max_diff
,
1e10
)
expected_slice
=
torch
.
tensor
([
0.451
,
0.451
,
0.4471
,
0.451
,
0.451
,
0.451
,
0.451
,
0.451
,
0.4784
,
0.4784
,
0.4784
,
0.4784
,
0.4784
,
0.4902
,
0.4588
,
0.5333
])
# fmt: on
generated_slice
=
generated_image
.
flatten
()
generated_slice
=
torch
.
cat
([
generated_slice
[:
8
],
generated_slice
[
-
8
:]])
self
.
assertTrue
(
torch
.
allclose
(
generated_slice
,
expected_slice
,
atol
=
1e-3
))
def
test_callback_inputs
(
self
):
def
test_callback_inputs
(
self
):
sig
=
inspect
.
signature
(
self
.
pipeline_class
.
__call__
)
sig
=
inspect
.
signature
(
self
.
pipeline_class
.
__call__
)
...
...
tests/pipelines/cosmos/test_cosmos2_video2world.py
View file @
d87134ad
...
@@ -147,11 +147,15 @@ class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCas
...
@@ -147,11 +147,15 @@ class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCas
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
=
self
.
get_dummy_inputs
(
device
)
video
=
pipe
(
**
inputs
).
frames
video
=
pipe
(
**
inputs
).
frames
generated_video
=
video
[
0
]
generated_video
=
video
[
0
]
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
32
,
32
))
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
32
,
32
))
expected_video
=
torch
.
randn
(
9
,
3
,
32
,
32
)
max_diff
=
np
.
abs
(
generated_video
-
expected_video
).
max
()
# fmt: off
self
.
assertLessEqual
(
max_diff
,
1e10
)
expected_slice
=
torch
.
tensor
([
0.451
,
0.451
,
0.4471
,
0.451
,
0.451
,
0.451
,
0.451
,
0.451
,
0.5098
,
0.5137
,
0.5176
,
0.5098
,
0.5255
,
0.5412
,
0.5098
,
0.5059
])
# fmt: on
generated_slice
=
generated_video
.
flatten
()
generated_slice
=
torch
.
cat
([
generated_slice
[:
8
],
generated_slice
[
-
8
:]])
self
.
assertTrue
(
torch
.
allclose
(
generated_slice
,
expected_slice
,
atol
=
1e-3
))
def
test_components_function
(
self
):
def
test_components_function
(
self
):
init_components
=
self
.
get_dummy_components
()
init_components
=
self
.
get_dummy_components
()
...
...
tests/pipelines/cosmos/test_cosmos_video2world.py
View file @
d87134ad
...
@@ -159,11 +159,15 @@ class CosmosVideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase
...
@@ -159,11 +159,15 @@ class CosmosVideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase
inputs
=
self
.
get_dummy_inputs
(
device
)
inputs
=
self
.
get_dummy_inputs
(
device
)
video
=
pipe
(
**
inputs
).
frames
video
=
pipe
(
**
inputs
).
frames
generated_video
=
video
[
0
]
generated_video
=
video
[
0
]
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
32
,
32
))
self
.
assertEqual
(
generated_video
.
shape
,
(
9
,
3
,
32
,
32
))
expected_video
=
torch
.
randn
(
9
,
3
,
32
,
32
)
max_diff
=
np
.
abs
(
generated_video
-
expected_video
).
max
()
# fmt: off
self
.
assertLessEqual
(
max_diff
,
1e10
)
expected_slice
=
torch
.
tensor
([
0.0
,
0.8275
,
0.7529
,
0.7294
,
0.0
,
0.6
,
1.0
,
0.3804
,
0.6667
,
0.0863
,
0.8784
,
0.5922
,
0.6627
,
0.2784
,
0.5725
,
0.7765
])
# fmt: on
generated_slice
=
generated_video
.
flatten
()
generated_slice
=
torch
.
cat
([
generated_slice
[:
8
],
generated_slice
[
-
8
:]])
self
.
assertTrue
(
torch
.
allclose
(
generated_slice
,
expected_slice
,
atol
=
1e-3
))
def
test_components_function
(
self
):
def
test_components_function
(
self
):
init_components
=
self
.
get_dummy_components
()
init_components
=
self
.
get_dummy_components
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment