Unverified Commit 7c6e9ef4 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] Fix how compiler mixin classes are used (#11680)

* fix how compiler tester mixins are used.

* propagate

* more
parent f46abfe4
...@@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model): ...@@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
return ip_state_dict return ip_state_dict
class FluxTransformerTests( class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
):
model_class = FluxTransformer2DModel model_class = FluxTransformer2DModel
main_input_name = "hidden_states" main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small. # We override the items here because the transformer under consideration is small.
...@@ -169,3 +167,17 @@ class FluxTransformerTests( ...@@ -169,3 +167,17 @@ class FluxTransformerTests(
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"} expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
...@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin ...@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism() enable_full_determinism()
class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states" main_input_name = "hidden_states"
uses_custom_attn_processor = True uses_custom_attn_processor = True
...@@ -93,7 +93,14 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, ...@@ -93,7 +93,14 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin,
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states" main_input_name = "hidden_states"
uses_custom_attn_processor = True uses_custom_attn_processor = True
...@@ -161,7 +168,14 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompi ...@@ -161,7 +168,14 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompi
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states" main_input_name = "hidden_states"
uses_custom_attn_processor = True uses_custom_attn_processor = True
...@@ -227,9 +241,14 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileT ...@@ -227,9 +241,14 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileT
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase model_class = HunyuanVideoTransformer3DModel
):
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states" main_input_name = "hidden_states"
uses_custom_attn_processor = True uses_custom_attn_processor = True
...@@ -295,3 +314,10 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests( ...@@ -295,3 +314,10 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"} expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
...@@ -26,7 +26,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin ...@@ -26,7 +26,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism() enable_full_determinism()
class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel model_class = LTXVideoTransformer3DModel
main_input_name = "hidden_states" main_input_name = "hidden_states"
uses_custom_attn_processor = True uses_custom_attn_processor = True
...@@ -81,3 +81,10 @@ class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.Te ...@@ -81,3 +81,10 @@ class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.Te
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTXVideoTransformer3DModel"} expected_set = {"LTXVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
...@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin ...@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism() enable_full_determinism()
class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel model_class = WanTransformer3DModel
main_input_name = "hidden_states" main_input_name = "hidden_states"
uses_custom_attn_processor = True uses_custom_attn_processor = True
...@@ -82,3 +82,10 @@ class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest. ...@@ -82,3 +82,10 @@ class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.
def test_gradient_checkpointing_is_applied(self): def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"} expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
...@@ -355,9 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): ...@@ -355,9 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs return custom_diffusion_attn_procs
class UNet2DConditionModelTests( class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
):
model_class = UNet2DConditionModel model_class = UNet2DConditionModel
main_input_name = "sample" main_input_name = "sample"
# We override the items here because the unet under consideration is small. # We override the items here because the unet under consideration is small.
...@@ -1147,6 +1145,20 @@ class UNet2DConditionModelTests( ...@@ -1147,6 +1145,20 @@ class UNet2DConditionModelTests(
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
@slow @slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase): class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape): def get_file_format(self, seed, shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment