Unverified Commit 351b37ea authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Fix UniPC tests and remove some test warnings (#2396)

* Change solver_type to match the previous tests.

* Prevent warnings about scale_model_inputs

* Prevent console log about division by zero.
parent 2e0d489a
......@@ -287,6 +287,11 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# Make sure `scale_model_input` is invoked to prevent a warning
if scheduler_class != VQDiffusionScheduler:
_ = scheduler.scale_model_input(sample, 0)
_ = new_scheduler.scale_model_input(sample, 0)
# Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
kwargs["generator"] = torch.manual_seed(0)
......@@ -597,7 +602,7 @@ class SchedulerCommonTest(unittest.TestCase):
continue
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.1, 0.3]))
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_pretrained(tmpdirname)
......@@ -2648,6 +2653,7 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
"beta_end": 0.02,
"beta_schedule": "linear",
"solver_order": 2,
"solver_type": "bh1",
}
config.update(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment