Unverified Commit 36acdd75 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] skip tests properly with `unittest.skip()` (#10527)

* skip tests properly.

* more

* more
parent e7db062e
...@@ -65,9 +65,11 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -65,9 +65,11 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
@unittest.skip("Test not supported.")
def test_forward_signature(self): def test_forward_signature(self):
pass pass
@unittest.skip("Test not supported.")
def test_training(self): def test_training(self):
pass pass
......
...@@ -51,9 +51,11 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -51,9 +51,11 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def output_shape(self): def output_shape(self):
return (4, 14, 16) return (4, 14, 16)
@unittest.skip("Test not supported.")
def test_ema_training(self): def test_ema_training(self):
pass pass
@unittest.skip("Test not supported.")
def test_training(self): def test_training(self):
pass pass
...@@ -126,6 +128,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -126,6 +128,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self): def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet # Not implemented yet for this UNet
pass pass
...@@ -205,9 +208,11 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -205,9 +208,11 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
@unittest.skip("Test not supported.")
def test_ema_training(self): def test_ema_training(self):
pass pass
@unittest.skip("Test not supported.")
def test_training(self): def test_training(self):
pass pass
...@@ -265,6 +270,7 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -265,6 +270,7 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self): def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet # Not implemented yet for this UNet
pass pass
...@@ -383,6 +383,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -383,6 +383,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self): def test_forward_with_norm_groups(self):
# not required for this model # not required for this model
pass pass
......
...@@ -320,6 +320,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes ...@@ -320,6 +320,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
assert output.shape == output_mix_time.shape assert output.shape == output_mix_time.shape
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self): def test_forward_with_norm_groups(self):
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups. # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
pass pass
...@@ -232,8 +232,10 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase ...@@ -232,8 +232,10 @@ class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase
def test_float16_inference(self): def test_float16_inference(self):
super().test_float16_inference() super().test_float16_inference()
@unittest.skip(reason="Test not supported.")
def test_callback_inputs(self): def test_callback_inputs(self):
pass pass
@unittest.skip(reason="Test not supported.")
def test_callback_cfg(self): def test_callback_cfg(self):
pass pass
import unittest
import torch import torch
from diffusers import DDIMInverseScheduler from diffusers import DDIMInverseScheduler
...@@ -95,6 +97,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest): ...@@ -95,6 +97,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
@unittest.skip("Test not supported.")
def test_add_noise_device(self): def test_add_noise_device(self):
pass pass
......
import tempfile import tempfile
import unittest
import torch import torch
...@@ -57,6 +58,7 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest): ...@@ -57,6 +58,7 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@unittest.skip("Test not supported.")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
pass pass
......
import tempfile import tempfile
import unittest
import torch import torch
...@@ -67,6 +68,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -67,6 +68,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@unittest.skip("Test not supported.")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
pass pass
......
import tempfile import tempfile
import unittest
import torch import torch
...@@ -65,6 +66,7 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest): ...@@ -65,6 +66,7 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@unittest.skip("Test not supported.")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
pass pass
......
...@@ -3,9 +3,7 @@ import unittest ...@@ -3,9 +3,7 @@ import unittest
import torch import torch
from diffusers import ( from diffusers import EDMDPMSolverMultistepScheduler
EDMDPMSolverMultistepScheduler,
)
from .test_schedulers import SchedulerCommonTest from .test_schedulers import SchedulerCommonTest
...@@ -63,6 +61,7 @@ class EDMDPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -63,6 +61,7 @@ class EDMDPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@unittest.skip("Test not supported.")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
pass pass
...@@ -258,5 +257,6 @@ class EDMDPMSolverMultistepSchedulerTest(SchedulerCommonTest): ...@@ -258,5 +257,6 @@ class EDMDPMSolverMultistepSchedulerTest(SchedulerCommonTest):
scheduler.set_timesteps(scheduler.config.num_train_timesteps) scheduler.set_timesteps(scheduler.config.num_train_timesteps)
assert len(scheduler.timesteps) == scheduler.num_inference_steps assert len(scheduler.timesteps) == scheduler.num_inference_steps
@unittest.skip("Test not supported.")
def test_trained_betas(self): def test_trained_betas(self):
pass pass
...@@ -675,6 +675,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): ...@@ -675,6 +675,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@unittest.skip("Test not supported.")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
pass pass
......
import tempfile import tempfile
import unittest
import torch import torch
...@@ -50,6 +51,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest): ...@@ -50,6 +51,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@unittest.skip("Test not supported.")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
pass pass
......
import tempfile import tempfile
import unittest
import torch import torch
...@@ -53,6 +54,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -53,6 +54,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@unittest.skip("Test not supported.")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
pass pass
......
import unittest
import torch import torch
from diffusers import UnCLIPScheduler from diffusers import UnCLIPScheduler
...@@ -130,8 +132,10 @@ class UnCLIPSchedulerTest(SchedulerCommonTest): ...@@ -130,8 +132,10 @@ class UnCLIPSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 258.2044983) < 1e-2 assert abs(result_sum.item() - 258.2044983) < 1e-2
assert abs(result_mean.item() - 0.3362038) < 1e-3 assert abs(result_mean.item() - 0.3362038) < 1e-3
@unittest.skip("Test not supported.")
def test_trained_betas(self): def test_trained_betas(self):
pass pass
@unittest.skip("Test not supported.")
def test_add_noise_device(self): def test_add_noise_device(self):
pass pass
import unittest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -52,5 +54,6 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest): ...@@ -52,5 +54,6 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest):
for t in [0, 50, 99]: for t in [0, 50, 99]:
self.check_over_forward(time_step=t) self.check_over_forward(time_step=t)
@unittest.skip("Test not supported.")
def test_add_noise_device(self): def test_add_noise_device(self):
pass pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment