"docs/source/en/vscode:/vscode.git/clone" did not exist on "5c9dd0af952a92f19a1e672b2a9471ad5674841d"
Unverified commit 5d550cfd, authored by Patrick von Platen, committed by GitHub

Make sure that DEIS, DPM and UniPC can correctly be switched in & out (#2595)

* [Schedulers] Correct config changing

* uP

* add tests
parent 24d624a4
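
The fix below makes each scheduler write the corrected algorithm_type / solver_type back into its registered config via register_to_config instead of only rebinding the local variable, so a config produced by one of the four schedulers can be loaded into any of the others and round-trips cleanly. A minimal sketch of the switching this enables, mirroring the new test_switch tests added further down (treat it as an illustration, not the exact test code):

from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    UniPCMultistepScheduler,
)

# Start from any of the four schedulers with its default config.
scheduler = DPMSolverMultistepScheduler()

# Each from_config call maps fields the target class does not support
# (e.g. algorithm_type="dpmsolver++" for DEIS) to a supported value and,
# with this commit, also registers that value in the new scheduler's config,
# so the chain stays consistent and ends up back where it started.
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
scheduler = DEISMultistepScheduler.from_config(scheduler.config)
scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
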
@@ -154,13 +154,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DEIS
         if algorithm_type not in ["deis"]:
             if algorithm_type in ["dpmsolver", "dpmsolver++"]:
-                algorithm_type = "deis"
+                self.register_to_config(algorithm_type="deis")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["logrho"]:
             if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
-                solver_type = "logrho"
+                self.register_to_config(solver_type="logrho")
             else:
                 raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
...
@@ -165,12 +165,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
             if algorithm_type == "deis":
-                algorithm_type = "dpmsolver++"
+                self.register_to_config(algorithm_type="dpmsolver++")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
-                solver_type = "midpoint"
+                self.register_to_config(solver_type="midpoint")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
...
@@ -164,12 +164,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         # settings for DPM-Solver
         if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
             if algorithm_type == "deis":
-                algorithm_type = "dpmsolver++"
+                self.register_to_config(algorithm_type="dpmsolver++")
             else:
                 raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
         if solver_type not in ["midpoint", "heun"]:
             if solver_type in ["logrho", "bh1", "bh2"]:
-                solver_type = "midpoint"
+                self.register_to_config(solver_type="midpoint")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
...
@@ -168,7 +168,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         if solver_type not in ["bh1", "bh2"]:
             if solver_type in ["midpoint", "heun", "logrho"]:
-                solver_type = "bh1"
+                self.register_to_config(solver_type="bh1")
             else:
                 raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
...
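
The pattern is the same in all four schedulers: previously __init__ only rebound the local algorithm_type / solver_type name, which left the value stored by register_to_config untouched, so self.config (and any config handed to another scheduler via from_config) still carried the unsupported value. Writing the mapped value back with register_to_config keeps the config in sync with what the scheduler actually uses. A toy illustration of the difference, using a stand-in for ConfigMixin rather than the real diffusers class:

# Toy stand-in for diffusers' ConfigMixin, only to show why the local
# rebinding was not enough; the real class also handles serialization.
class ToyConfigMixin:
    def __init__(self):
        self.config = {}

    def register_to_config(self, **kwargs):
        # Whatever is registered here is what self.config reports and what
        # from_config-style round trips see later on.
        self.config.update(kwargs)


class ToyScheduler(ToyConfigMixin):
    def __init__(self, solver_type="midpoint"):
        super().__init__()
        self.register_to_config(solver_type=solver_type)
        if solver_type in ["logrho", "bh1", "bh2"]:
            # Old pattern: `solver_type = "midpoint"` changed only the local name,
            # leaving self.config["solver_type"] == "bh1".
            # New pattern (this commit): write the mapped value into the config too.
            self.register_to_config(solver_type="midpoint")


scheduler = ToyScheduler(solver_type="bh1")
assert scheduler.config["solver_type"] == "midpoint"  # config matches the value in use
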
@@ -953,7 +953,12 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
@@ -973,6 +978,25 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)

+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2791) < 1e-3
+
     def test_thresholding(self):
         self.check_over_configs(thresholding=False)
         for order in [1, 2, 3]:
@@ -1130,7 +1154,8 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def full_loop(self, **config):
-        scheduler_class = self.scheduler_classes[0]
-        scheduler_config = self.get_scheduler_config(**config)
-        scheduler = scheduler_class(**scheduler_config)
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
@@ -1244,6 +1269,25 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 0.2251) < 1e-3

+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DPMSolverMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.3301) < 1e-3
+
     def test_fp16_support(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
@@ -2543,7 +2587,12 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
@@ -2589,6 +2638,25 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)

+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = DEISMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.23916) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
@@ -2742,7 +2810,12 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def full_loop(self, **config):
+    def full_loop(self, scheduler=None, **config):
+        if scheduler is None:
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config(**config)
         scheduler = scheduler_class(**scheduler_config)
@@ -2788,6 +2861,25 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)

+    def test_switch(self):
+        # make sure that iterating over schedulers with same config names gives same results
+        # for defaults
+        scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
+        scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
+        scheduler = DEISMultistepScheduler.from_config(scheduler.config)
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+        scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
+
+        sample = self.full_loop(scheduler=scheduler)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.2521) < 1e-3
+
     def test_timesteps(self):
         for timesteps in [25, 50, 100, 999, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
...