Unverified Commit 6e123688 authored by Justin Ruan's avatar Justin Ruan Committed by GitHub
Browse files

Remove unused parameters and fixed `FutureWarning` (#6317)

* Remove unused parameters and fixed `FutureWarning`

* Fixed wrong config instance

* update unittest for `DDIMInverseScheduler`
parent f0a588b8
...@@ -293,9 +293,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -293,9 +293,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
timestep: int, timestep: int,
sample: torch.FloatTensor, sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]: ) -> Union[DDIMSchedulerOutput, Tuple]:
""" """
...@@ -332,7 +329,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -332,7 +329,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
# 1. get previous step value (=t+1) # 1. get previous step value (=t+1)
prev_timestep = timestep prev_timestep = timestep
timestep = min( timestep = min(
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps - 1 timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
) )
# 2. compute alphas, betas # 2. compute alphas, betas
......
...@@ -7,7 +7,7 @@ from .test_schedulers import SchedulerCommonTest ...@@ -7,7 +7,7 @@ from .test_schedulers import SchedulerCommonTest
class DDIMInverseSchedulerTest(SchedulerCommonTest): class DDIMInverseSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDIMInverseScheduler,) scheduler_classes = (DDIMInverseScheduler,)
forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50)) forward_default_kwargs = (("num_inference_steps", 50),)
def get_scheduler_config(self, **kwargs): def get_scheduler_config(self, **kwargs):
config = { config = {
...@@ -26,7 +26,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest): ...@@ -26,7 +26,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config(**config) scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config) scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.0 num_inference_steps = 10
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter sample = self.dummy_sample_deter
...@@ -35,7 +35,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest): ...@@ -35,7 +35,7 @@ class DDIMInverseSchedulerTest(SchedulerCommonTest):
for t in scheduler.timesteps: for t in scheduler.timesteps:
residual = model(sample, t) residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta).prev_sample sample = scheduler.step(residual, t, sample).prev_sample
return sample return sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment