"tests/vscode:/vscode.git/clone" did not exist on "f26cde3dff6b288b4c6e5c84a287373aa8c8a689"
Unverified commit 5d550cfd, authored by Patrick von Platen, committed by GitHub

Make sure that DEIS, DPM and UniPC can correctly be switched in & out (#2595)

* [Schedulers] Correct config changing

* uP

* add tests
parent 24d624a4
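The user-facing workflow this commit protects is swapping schedulers on a loaded pipeline via from_config. A minimal sketch of that round trip (the checkpoint id and the particular order of swaps are only placeholders, not part of this change):

# Sketch of the scheduler-swapping workflow this commit keeps working.
from diffusers import (
    DiffusionPipeline,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
)

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Each scheduler is rebuilt from the previous one's config; incompatible settings
# such as algorithm_type / solver_type must be remapped rather than silently kept.
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)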
@@ -154,13 +154,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DEIS
         if algorithm_type not in ["deis"]:
             if algorithm_type in ["dpmsolver", "dpmsolver++"]:
-                algorithm_type = "deis"
+                self.register_to_config(algorithm_type="deis")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
 
         if solver_type not in ["logrho"]:
             if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
-                solver_type = "logrho"
+                self.register_to_config(solver_type="logrho")
             else:
                 raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
@@ -165,12 +165,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
             if algorithm_type == "deis":
-                algorithm_type = "dpmsolver++"
+                self.register_to_config(algorithm_type="dpmsolver++")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
 
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
-                solver_type = "midpoint"
+                self.register_to_config(solver_type="midpoint")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
@@ -164,12 +164,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
             if algorithm_type == "deis":
-                algorithm_type = "dpmsolver++"
+                self.register_to_config(algorithm_type="dpmsolver++")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
 
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
-                solver_type = "midpoint"
+                self.register_to_config(solver_type="midpoint")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
@@ -168,7 +168,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         if solver_type not in ["bh1", "bh2"]:
             if solver_type in ["midpoint", "heun", "logrho"]:
-                solver_type = "bh1"
+                self.register_to_config(solver_type="bh1")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
@@ -953,7 +953,12 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
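With this refactor, full_loop either builds a default scheduler from config kwargs (old behaviour) or runs a caller-supplied instance, which the new test_switch tests below rely on. Rough usage from inside the test class (prediction_type is just an example override):

# Old-style call: the helper builds the scheduler itself from config kwargs.
sample = self.full_loop(prediction_type="v_prediction")

# New-style call used by test_switch: a pre-built (possibly converted) scheduler
# is passed in and the kwargs construction path is skipped.
scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
sample = self.full_loop(scheduler=scheduler)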
@@ -973,6 +978,25 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
     def test_thresholding(self):
         self.check_over_configs(thresholding=False)
         for order in [1, 2, 3]:
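Each from_config call in the chain above re-runs the target scheduler's __init__ on the previous scheduler's config, where the remapping shown earlier normalises any foreign algorithm_type / solver_type before re-registering it, so the chain ends on a config equivalent to the defaults and full_loop reproduces the same mean. A hypothetical spot-check of a single hop (not part of the committed test):

# After one conversion, the remapped value is what the new scheduler's config reports.
single = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
deis = DEISMultistepScheduler.from_config(single.config)
assert deis.config.algorithm_type == "deis"  # remapped from the default "dpmsolver++"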
@@ -1130,10 +1154,11 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
 
         num_inference_steps = 10
         model = self.dummy_model()
@@ -1244,6 +1269,25 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 0.2251) < 1e-3
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
@@ -2543,7 +2587,12 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
@@ -2589,6 +2638,25 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DEISMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
@@ -2742,7 +2810,12 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
@@ -2788,6 +2861,25 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)