Unverified Commit fe794894 authored by Pierre Chapuis's avatar Pierre Chapuis Committed by GitHub
Browse files

allow tensors in several schedulers step() call (#8905)

parent 461efc57
...@@ -674,7 +674,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -674,7 +674,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: Union[int, torch.Tensor],
sample: torch.Tensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
...@@ -685,7 +685,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -685,7 +685,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
model_output (`torch.Tensor`): model_output (`torch.Tensor`):
The direct output from learned diffusion model. The direct output from learned diffusion model.
timestep (`float`): timestep (`int`):
The current discrete timestep in the diffusion chain. The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`): sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process. A current instance of a sample created by the diffusion process.
......
...@@ -920,7 +920,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -920,7 +920,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: Union[int, torch.Tensor],
sample: torch.Tensor, sample: torch.Tensor,
generator=None, generator=None,
variance_noise: Optional[torch.Tensor] = None, variance_noise: Optional[torch.Tensor] = None,
......
...@@ -787,7 +787,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -787,7 +787,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: Union[int, torch.Tensor],
sample: torch.Tensor, sample: torch.Tensor,
generator=None, generator=None,
variance_noise: Optional[torch.Tensor] = None, variance_noise: Optional[torch.Tensor] = None,
......
...@@ -927,7 +927,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -927,7 +927,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: Union[int, torch.Tensor],
sample: torch.Tensor, sample: torch.Tensor,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
......
...@@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: Union[int, torch.Tensor],
sample: torch.Tensor, sample: torch.Tensor,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
......
...@@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: Union[int, torch.Tensor],
sample: torch.Tensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
......
...@@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def step( def step(
self, self,
model_output: torch.Tensor, model_output: torch.Tensor,
timestep: int, timestep: Union[int, torch.Tensor],
sample: torch.Tensor, sample: torch.Tensor,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[SchedulerOutput, Tuple]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment