Unverified Commit a3efa433 authored by Hamish Friedlander's avatar Hamish Friedlander Committed by GitHub
Browse files

Fix DDIM on Windows not using int64 for timesteps (#819)

parent 728a3f3e
...@@ -149,7 +149,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -149,7 +149,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# setable values # setable values
self.num_inference_steps = None self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
""" """
...@@ -192,7 +192,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -192,7 +192,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
step_ratio = self.config.num_train_timesteps // self.num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio # creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3 # casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps += offset self.timesteps += offset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment