Unverified commit 5d550cfd authored by Patrick von Platen, committed by GitHub

Make sure that DEIS, DPM and UniPC can correctly be switched in & out (#2595)

* [Schedulers] Correct config changing

* uP

* add tests
parent 24d624a4
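
For context, the behavior this commit protects is the standard diffusers pattern of swapping a pipeline's scheduler through from_config. The snippet below is a minimal sketch of that usage and is not code from this diff; the checkpoint name, prompt, and step count are only illustrative.

# Minimal sketch (illustrative, not part of the diff): switching schedulers on a pipeline.
from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UniPCMultistepScheduler,
)

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Rebuild the scheduler from the previous scheduler's config at each step; after this
# commit, cycling through DEIS, DPM-Solver and UniPC leaves the config consistent.
for scheduler_cls in (DEISMultistepScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler):
    pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
    image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]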
@@ -154,13 +154,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DEIS
         if algorithm_type not in ["deis"]:
             if algorithm_type in ["dpmsolver", "dpmsolver++"]:
-                algorithm_type = "deis"
+                self.register_to_config(algorithm_type="deis")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["logrho"]:
             if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
-                solver_type = "logrho"
+                self.register_to_config(solver_type="logrho")
             else:
                 raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
...
@@ -165,12 +165,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
             if algorithm_type == "deis":
-                algorithm_type = "dpmsolver++"
+                self.register_to_config(algorithm_type="dpmsolver++")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
-                solver_type = "midpoint"
+                self.register_to_config(solver_type="midpoint")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
...
@@ -164,12 +164,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
             if algorithm_type == "deis":
-                algorithm_type = "dpmsolver++"
+                self.register_to_config(algorithm_type="dpmsolver++")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
-                solver_type = "midpoint"
+                self.register_to_config(solver_type="midpoint")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
...
@@ -168,7 +168,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         if solver_type not in ["bh1", "bh2"]:
             if solver_type in ["midpoint", "heun", "logrho"]:
-                solver_type = "bh1"
+                self.register_to_config(solver_type="bh1")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
...
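
The scheduler-side change in all four hunks above is the same: when an incompatible algorithm_type or solver_type arrives from another scheduler's config, the remapped default is written back with self.register_to_config(...) instead of only reassigning the local variable, so the corrected value is persisted into scheduler.config and is what the next from_config call reads. A minimal sketch of the resulting round trip, using only the public diffusers API and not code from this diff:

# Sketch of the config round trip this commit makes consistent.
from diffusers import DEISMultistepScheduler, DPMSolverMultistepScheduler

# DPM-Solver multistep defaults to algorithm_type="dpmsolver++", which DEIS does not support.
dpm = DPMSolverMultistepScheduler()
deis = DEISMultistepScheduler.from_config(dpm.config)

# Because the remapped value is registered to the config rather than kept in a local
# variable, it is the value the next from_config call sees.
assert deis.config.algorithm_type == "deis"

# Switching back remaps it again, so repeated switching stays stable.
dpm_again = DPMSolverMultistepScheduler.from_config(deis.config)
assert dpm_again.config.algorithm_type == "dpmsolver++"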
@@ -953,7 +953,12 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
@@ -973,6 +978,25 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
     def test_thresholding(self):
         self.check_over_configs(thresholding=False)
         for order in [1, 2, 3]:
@@ -1130,7 +1154,8 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
@@ -1244,6 +1269,25 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 0.2251) < 1e-3
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
@@ -2543,7 +2587,12 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
@@ -2589,6 +2638,25 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DEISMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
@@ -2742,7 +2810,12 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
@@ -2788,6 +2861,25 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
...