Unverified Commit 35fd84be authored by dg845's avatar dg845 Committed by GitHub
Browse files

Replace hardcoded values in SchedulerCommonTest with properties (#5479)




---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent f2756253
...@@ -36,10 +36,10 @@ from diffusers import ( ...@@ -36,10 +36,10 @@ from diffusers import (
LMSDiscreteScheduler, LMSDiscreteScheduler,
UniPCMultistepScheduler, UniPCMultistepScheduler,
VQDiffusionScheduler, VQDiffusionScheduler,
logging,
) )
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger, torch_device from diffusers.utils.testing_utils import CaptureLogger, torch_device
from ..others.test_utils import TOKEN, USER, is_staging_test from ..others.test_utils import TOKEN, USER, is_staging_test
...@@ -48,6 +48,9 @@ from ..others.test_utils import TOKEN, USER, is_staging_test ...@@ -48,6 +48,9 @@ from ..others.test_utils import TOKEN, USER, is_staging_test
torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SchedulerObject(SchedulerMixin, ConfigMixin): class SchedulerObject(SchedulerMixin, ConfigMixin):
config_name = "config.json" config_name = "config.json"
...@@ -253,6 +256,60 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -253,6 +256,60 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_classes = () scheduler_classes = ()
forward_default_kwargs = () forward_default_kwargs = ()
@property
def default_num_inference_steps(self):
return 50
@property
def default_timestep(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
try:
scheduler_config = self.get_scheduler_config()
scheduler = self.scheduler_classes[0](**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
timestep = scheduler.timesteps[0]
except NotImplementedError:
logger.warning(
f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
f" `default_timestep` will be set to the default value of 1."
)
timestep = 1
return timestep
# NOTE: currently taking the convention that default_timestep > default_timestep_2 (alternatively,
# default_timestep comes earlier in the timestep schedule than default_timestep_2)
@property
def default_timestep_2(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
try:
scheduler_config = self.get_scheduler_config()
scheduler = self.scheduler_classes[0](**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
if len(scheduler.timesteps) >= 2:
timestep_2 = scheduler.timesteps[1]
else:
logger.warning(
f"Using num_inference_steps from the scheduler testing class's default config leads to a timestep"
f" scheduler of length {len(scheduler.timesteps)} < 2. The default `default_timestep_2` value of 0"
f" will be used."
)
timestep_2 = 0
except NotImplementedError:
logger.warning(
f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
f" `default_timestep_2` will be set to the default value of 0."
)
timestep_2 = 0
return timestep_2
@property @property
def dummy_sample(self): def dummy_sample(self):
batch_size = 4 batch_size = 4
...@@ -313,6 +370,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -313,6 +370,7 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None) num_inference_steps = kwargs.pop("num_inference_steps", None)
time_step = time_step if time_step is not None else self.default_timestep
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
...@@ -371,6 +429,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -371,6 +429,7 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs.update(forward_kwargs) kwargs.update(forward_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None) num_inference_steps = kwargs.pop("num_inference_steps", None)
time_step = time_step if time_step is not None else self.default_timestep
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
...@@ -411,10 +470,10 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -411,10 +470,10 @@ class SchedulerCommonTest(unittest.TestCase):
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None) num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
timestep = 1 timestep = self.default_timestep
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
timestep = float(timestep) timestep = float(timestep)
...@@ -497,10 +556,10 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -497,10 +556,10 @@ class SchedulerCommonTest(unittest.TestCase):
def test_step_shape(self): def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None) num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
timestep_0 = 1 timestep_0 = self.default_timestep
timestep_1 = 0 timestep_1 = self.default_timestep_2
for scheduler_class in self.scheduler_classes: for scheduler_class in self.scheduler_classes:
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler): if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
...@@ -558,9 +617,9 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -558,9 +617,9 @@ class SchedulerCommonTest(unittest.TestCase):
) )
kwargs = dict(self.forward_default_kwargs) kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", 50) num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
timestep = 0 timestep = self.default_timestep
if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler: if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
timestep = 1 timestep = 1
...@@ -644,7 +703,7 @@ class SchedulerCommonTest(unittest.TestCase): ...@@ -644,7 +703,7 @@ class SchedulerCommonTest(unittest.TestCase):
continue continue
scheduler_config = self.get_scheduler_config() scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(100) scheduler.set_timesteps(self.default_num_inference_steps)
sample = self.dummy_sample.to(torch_device) sample = self.dummy_sample.to(torch_device)
if scheduler_class == CMStochasticIterativeScheduler: if scheduler_class == CMStochasticIterativeScheduler:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment