"vscode:/vscode.git/clone" did not exist on "52e48739a57e29ba47c238b2bbf06a391066da57"
Unverified Commit fe794894 authored by Pierre Chapuis's avatar Pierre Chapuis Committed by GitHub
Browse files

allow tensors in several schedulers step() call (#8905)

parent 461efc57
......@@ -674,7 +674,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
......@@ -685,7 +685,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
......
......@@ -920,7 +920,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
......
......@@ -787,7 +787,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
......
......@@ -927,7 +927,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
......
......@@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
return_dict: bool = True,
......
......@@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
......
......@@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
model_output: torch.Tensor,
timestep: int,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment