Unverified Commit 7c6e9ef4 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] Fix how compiler mixin classes are used (#11680)

* fix how compiler tester mixins are used.

* propagate

* more
parent f46abfe4
......@@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
return ip_state_dict
class FluxTransformerTests(
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
):
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
......@@ -169,3 +167,17 @@ class FluxTransformerTests(
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
......@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
......@@ -93,7 +93,14 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin,
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
......@@ -161,7 +168,14 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompi
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
......@@ -227,9 +241,14 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileT
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
):
class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
......@@ -295,3 +314,10 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
......@@ -26,7 +26,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
......@@ -81,3 +81,10 @@ class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.Te
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTXVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
......@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
......@@ -82,3 +82,10 @@ class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
......@@ -355,9 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs
class UNet2DConditionModelTests(
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
):
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
......@@ -1147,6 +1145,20 @@ class UNet2DConditionModelTests(
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment