Unverified commit 35fd84be, authored by dg845 and committed by GitHub

Replace hardcoded values in SchedulerCommonTest with properties (#5479)




---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent f2756253
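
In short, this commit replaces hardcoded literals in SchedulerCommonTest (e.g. `timestep = 1`, `num_inference_steps = 50`, `set_timesteps(100)`) with `default_num_inference_steps`, `default_timestep`, and `default_timestep_2` properties that derive their values from the scheduler class under test. A minimal sketch of the idea; the standalone DDPMScheduler framing here is illustrative, not part of the diff:

    # Rather than hardcoding a timestep, read defaults off the scheduler's own
    # schedule (this mirrors what the new properties in the diff do internally).
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(50)  # 50 matches the new default_num_inference_steps

    default_timestep = scheduler.timesteps[0]    # first (highest-noise) timestep
    default_timestep_2 = scheduler.timesteps[1]  # second timestep in the schedule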
@@ -36,10 +36,10 @@ from diffusers import (
     LMSDiscreteScheduler,
     UniPCMultistepScheduler,
     VQDiffusionScheduler,
+    logging,
 )
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from diffusers.utils import logging
 from diffusers.utils.testing_utils import CaptureLogger, torch_device
 
 from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -48,6 +48,9 @@ from ..others.test_utils import TOKEN, USER, is_staging_test
 torch.backends.cuda.matmul.allow_tf32 = False
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 class SchedulerObject(SchedulerMixin, ConfigMixin):
     config_name = "config.json"
@@ -253,6 +256,60 @@ class SchedulerCommonTest(unittest.TestCase):
     scheduler_classes = ()
     forward_default_kwargs = ()
 
+    @property
+    def default_num_inference_steps(self):
+        return 50
+
+    @property
+    def default_timestep(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
+
+        try:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = self.scheduler_classes[0](**scheduler_config)
+
+            scheduler.set_timesteps(num_inference_steps)
+            timestep = scheduler.timesteps[0]
+        except NotImplementedError:
+            logger.warning(
+                f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
+                f" `default_timestep` will be set to the default value of 1."
+            )
+            timestep = 1
+
+        return timestep
+
+    # NOTE: currently taking the convention that default_timestep > default_timestep_2 (alternatively,
+    # default_timestep comes earlier in the timestep schedule than default_timestep_2)
+    @property
+    def default_timestep_2(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
+
+        try:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = self.scheduler_classes[0](**scheduler_config)
+
+            scheduler.set_timesteps(num_inference_steps)
+            if len(scheduler.timesteps) >= 2:
+                timestep_2 = scheduler.timesteps[1]
+            else:
+                logger.warning(
+                    f"Using num_inference_steps from the scheduler testing class's default config leads to a timestep"
+                    f" schedule of length {len(scheduler.timesteps)} < 2. The default `default_timestep_2` value of 0"
+                    f" will be used."
+                )
+                timestep_2 = 0
+        except NotImplementedError:
+            logger.warning(
+                f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
+                f" `default_timestep_2` will be set to the default value of 0."
+            )
+            timestep_2 = 0
+
+        return timestep_2
+
     @property
     def dummy_sample(self):
         batch_size = 4
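
The properties above resolve against `self.scheduler_classes[0]` and `self.get_scheduler_config()`, both of which concrete test classes supply; classes whose `get_scheduler_config` raises NotImplementedError fall back to the old constants (1 and 0) with a warning. A hypothetical subclass sketch, assuming `SchedulerCommonTest` and `DDPMScheduler` are in scope as in this file; the real per-scheduler test classes live in their own files:

    # Hypothetical subclass for illustration: with this config, default_timestep
    # evaluates DDPMScheduler's 50-step schedule and returns its first entry
    # instead of the previously hardcoded 1.
    class DDPMSchedulerTest(SchedulerCommonTest):
        scheduler_classes = (DDPMScheduler,)

        def get_scheduler_config(self, **kwargs):
            config = {"num_train_timesteps": 1000}
            config.update(**kwargs)
            return config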
@@ -313,6 +370,7 @@ class SchedulerCommonTest(unittest.TestCase):
         kwargs = dict(self.forward_default_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
+        time_step = time_step if time_step is not None else self.default_timestep
 
         for scheduler_class in self.scheduler_classes:
             # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
@@ -371,6 +429,7 @@ class SchedulerCommonTest(unittest.TestCase):
         kwargs.update(forward_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
+        time_step = time_step if time_step is not None else self.default_timestep
 
         for scheduler_class in self.scheduler_classes:
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
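
Both hunks above use the same sentinel idiom: tests may pass `time_step=None` and have it resolved lazily against `default_timestep`. The explicit `is not None` check matters because 0 is a valid timestep; a standalone sketch (the helper name is mine, for illustration only):

    def resolve_timestep(time_step, default):
        # `time_step or default` would wrongly discard a legitimate timestep
        # of 0, so the sentinel check must be `is not None`.
        return time_step if time_step is not None else default

    assert resolve_timestep(None, 980) == 980  # None -> schedule-derived default
    assert resolve_timestep(0, 980) == 0       # an explicit 0 is preserved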
@@ -411,10 +470,10 @@ class SchedulerCommonTest(unittest.TestCase):
     def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
 
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
 
         for scheduler_class in self.scheduler_classes:
-            timestep = 1
+            timestep = self.default_timestep
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
                 timestep = float(timestep)
@@ -497,10 +556,10 @@ class SchedulerCommonTest(unittest.TestCase):
     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
 
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
 
-        timestep_0 = 1
-        timestep_1 = 0
+        timestep_0 = self.default_timestep
+        timestep_1 = self.default_timestep_2
 
         for scheduler_class in self.scheduler_classes:
             if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
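
`test_step_shape` steps the scheduler at two distinct timesteps, so it consumes both properties; per the NOTE in the new code, `default_timestep` is expected to come earlier in the schedule (and be numerically larger) than `default_timestep_2`. That ordering also holds for the float-valued schedules behind the Euler/LMS cast above; a quick illustrative check, assuming a default-config scheduler with the usual descending timestep order:

    from diffusers import EulerDiscreteScheduler

    scheduler = EulerDiscreteScheduler()
    scheduler.set_timesteps(50)
    # diffusers schedules run from noisiest to cleanest, so index 0 > index 1
    assert scheduler.timesteps[0] > scheduler.timesteps[1]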
@@ -558,9 +617,9 @@ class SchedulerCommonTest(unittest.TestCase):
         )
 
         kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", 50)
+        num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
 
-        timestep = 0
+        timestep = self.default_timestep
         if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
             timestep = 1
@@ -644,7 +703,7 @@ class SchedulerCommonTest(unittest.TestCase):
                 continue
 
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
-            scheduler.set_timesteps(100)
+            scheduler.set_timesteps(self.default_num_inference_steps)
             sample = self.dummy_sample.to(torch_device)
             if scheduler_class == CMStochasticIterativeScheduler: