Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7c6e9ef4
Unverified
Commit
7c6e9ef4
authored
Jun 09, 2025
by
Sayak Paul
Committed by
GitHub
Jun 09, 2025
Browse files
[tests] Fix how compiler mixin classes are used (#11680)
* fix how compiler tester mixins are used. * propagate * more
parent
f46abfe4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
78 additions
and
14 deletions
+78
-14
tests/models/transformers/test_models_transformer_flux.py
tests/models/transformers/test_models_transformer_flux.py
+15
-3
tests/models/transformers/test_models_transformer_hunyuan_video.py
...els/transformers/test_models_transformer_hunyuan_video.py
+32
-6
tests/models/transformers/test_models_transformer_ltx.py
tests/models/transformers/test_models_transformer_ltx.py
+8
-1
tests/models/transformers/test_models_transformer_wan.py
tests/models/transformers/test_models_transformer_wan.py
+8
-1
tests/models/unets/test_models_unet_2d_condition.py
tests/models/unets/test_models_unet_2d_condition.py
+15
-3
No files found.
tests/models/transformers/test_models_transformer_flux.py
View file @
7c6e9ef4
...
...
@@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
return
ip_state_dict
class
FluxTransformerTests
(
ModelTesterMixin
,
TorchCompileTesterMixin
,
LoraHotSwappingForModelTesterMixin
,
unittest
.
TestCase
):
class
FluxTransformerTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
FluxTransformer2DModel
main_input_name
=
"hidden_states"
# We override the items here because the transformer under consideration is small.
...
...
@@ -169,3 +167,17 @@ class FluxTransformerTests(
def
test_gradient_checkpointing_is_applied
(
self
):
expected_set
=
{
"FluxTransformer2DModel"
}
super
().
test_gradient_checkpointing_is_applied
(
expected_set
=
expected_set
)
class
FluxTransformerCompileTests
(
TorchCompileTesterMixin
,
unittest
.
TestCase
):
model_class
=
FluxTransformer2DModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
FluxTransformerTests
().
prepare_init_args_and_inputs_for_common
()
class
FluxTransformerLoRAHotSwapTests
(
LoraHotSwappingForModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
FluxTransformer2DModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
FluxTransformerTests
().
prepare_init_args_and_inputs_for_common
()
tests/models/transformers/test_models_transformer_hunyuan_video.py
View file @
7c6e9ef4
...
...
@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism
()
class
HunyuanVideoTransformer3DTests
(
ModelTesterMixin
,
TorchCompileTesterMixin
,
unittest
.
TestCase
):
class
HunyuanVideoTransformer3DTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
HunyuanVideoTransformer3DModel
main_input_name
=
"hidden_states"
uses_custom_attn_processor
=
True
...
...
@@ -93,7 +93,14 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin,
super
().
test_gradient_checkpointing_is_applied
(
expected_set
=
expected_set
)
class
HunyuanSkyreelsImageToVideoTransformer3DTests
(
ModelTesterMixin
,
TorchCompileTesterMixin
,
unittest
.
TestCase
):
class
HunyuanTransformerCompileTests
(
TorchCompileTesterMixin
,
unittest
.
TestCase
):
model_class
=
HunyuanVideoTransformer3DModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
HunyuanVideoTransformer3DTests
().
prepare_init_args_and_inputs_for_common
()
class
HunyuanSkyreelsImageToVideoTransformer3DTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
HunyuanVideoTransformer3DModel
main_input_name
=
"hidden_states"
uses_custom_attn_processor
=
True
...
...
@@ -161,7 +168,14 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompi
super
().
test_gradient_checkpointing_is_applied
(
expected_set
=
expected_set
)
class
HunyuanVideoImageToVideoTransformer3DTests
(
ModelTesterMixin
,
TorchCompileTesterMixin
,
unittest
.
TestCase
):
class
HunyuanSkyreelsImageToVideoCompileTests
(
TorchCompileTesterMixin
,
unittest
.
TestCase
):
model_class
=
HunyuanVideoTransformer3DModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
HunyuanSkyreelsImageToVideoTransformer3DTests
().
prepare_init_args_and_inputs_for_common
()
class
HunyuanVideoImageToVideoTransformer3DTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
HunyuanVideoTransformer3DModel
main_input_name
=
"hidden_states"
uses_custom_attn_processor
=
True
...
...
@@ -227,9 +241,14 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileT
super
().
test_gradient_checkpointing_is_applied
(
expected_set
=
expected_set
)
class
HunyuanVideoTokenReplaceImageToVideoTransformer3DTests
(
ModelTesterMixin
,
TorchCompileTesterMixin
,
unittest
.
TestCase
):
class
HunyuanImageToVideoCompileTests
(
TorchCompileTesterMixin
,
unittest
.
TestCase
):
model_class
=
HunyuanVideoTransformer3DModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
HunyuanVideoImageToVideoTransformer3DTests
().
prepare_init_args_and_inputs_for_common
()
class
HunyuanVideoTokenReplaceImageToVideoTransformer3DTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
HunyuanVideoTransformer3DModel
main_input_name
=
"hidden_states"
uses_custom_attn_processor
=
True
...
...
@@ -295,3 +314,10 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
def
test_gradient_checkpointing_is_applied
(
self
):
expected_set
=
{
"HunyuanVideoTransformer3DModel"
}
super
().
test_gradient_checkpointing_is_applied
(
expected_set
=
expected_set
)
class
HunyuanVideoTokenReplaceCompileTests
(
TorchCompileTesterMixin
,
unittest
.
TestCase
):
model_class
=
HunyuanVideoTransformer3DModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
HunyuanVideoTokenReplaceImageToVideoTransformer3DTests
().
prepare_init_args_and_inputs_for_common
()
tests/models/transformers/test_models_transformer_ltx.py
View file @
7c6e9ef4
...
...
@@ -26,7 +26,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism
()
class
LTXTransformerTests
(
ModelTesterMixin
,
TorchCompileTesterMixin
,
unittest
.
TestCase
):
class
LTXTransformerTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
LTXVideoTransformer3DModel
main_input_name
=
"hidden_states"
uses_custom_attn_processor
=
True
...
...
@@ -81,3 +81,10 @@ class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.Te
def
test_gradient_checkpointing_is_applied
(
self
):
expected_set
=
{
"LTXVideoTransformer3DModel"
}
super
().
test_gradient_checkpointing_is_applied
(
expected_set
=
expected_set
)
class
LTXTransformerCompileTests
(
TorchCompileTesterMixin
,
unittest
.
TestCase
):
model_class
=
LTXVideoTransformer3DModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
LTXTransformerTests
().
prepare_init_args_and_inputs_for_common
()
tests/models/transformers/test_models_transformer_wan.py
View file @
7c6e9ef4
...
...
@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism
()
class
WanTransformer3DTests
(
ModelTesterMixin
,
TorchCompileTesterMixin
,
unittest
.
TestCase
):
class
WanTransformer3DTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
WanTransformer3DModel
main_input_name
=
"hidden_states"
uses_custom_attn_processor
=
True
...
...
@@ -82,3 +82,10 @@ class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.
def
test_gradient_checkpointing_is_applied
(
self
):
expected_set
=
{
"WanTransformer3DModel"
}
super
().
test_gradient_checkpointing_is_applied
(
expected_set
=
expected_set
)
class
WanTransformerCompileTests
(
TorchCompileTesterMixin
,
unittest
.
TestCase
):
model_class
=
WanTransformer3DModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
WanTransformer3DTests
().
prepare_init_args_and_inputs_for_common
()
tests/models/unets/test_models_unet_2d_condition.py
View file @
7c6e9ef4
...
...
@@ -355,9 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return
custom_diffusion_attn_procs
class
UNet2DConditionModelTests
(
ModelTesterMixin
,
TorchCompileTesterMixin
,
LoraHotSwappingForModelTesterMixin
,
UNetTesterMixin
,
unittest
.
TestCase
):
class
UNet2DConditionModelTests
(
ModelTesterMixin
,
UNetTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNet2DConditionModel
main_input_name
=
"sample"
# We override the items here because the unet under consideration is small.
...
...
@@ -1147,6 +1145,20 @@ class UNet2DConditionModelTests(
assert
"Using the `save_attn_procs()` method has been deprecated"
in
warning_message
class
UNet2DConditionModelCompileTests
(
TorchCompileTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNet2DConditionModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
UNet2DConditionModelTests
().
prepare_init_args_and_inputs_for_common
()
class
UNet2DConditionModelLoRAHotSwapTests
(
LoraHotSwappingForModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNet2DConditionModel
def
prepare_init_args_and_inputs_for_common
(
self
):
return
UNet2DConditionModelTests
().
prepare_init_args_and_inputs_for_common
()
@
slow
class
UNet2DConditionModelIntegrationTests
(
unittest
.
TestCase
):
def
get_file_format
(
self
,
seed
,
shape
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment