Commit f30b992b authored by comfyanonymous's avatar comfyanonymous
Browse files

.sigma and .timestep now return tensors on the same device as the input.

parent 488de0b4
...@@ -65,15 +65,15 @@ class ModelSamplingDiscrete(torch.nn.Module): ...@@ -65,15 +65,15 @@ class ModelSamplingDiscrete(torch.nn.Module):
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
def sigma(self, timestep): def sigma(self, timestep):
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long() low_idx = t.floor().long()
high_idx = t.ceil().long() high_idx = t.ceil().long()
w = t.frac() w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp() return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0:
......
...@@ -56,15 +56,15 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): ...@@ -56,15 +56,15 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
def sigma(self, timestep): def sigma(self, timestep):
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long() low_idx = t.floor().long()
high_idx = t.ceil().long() high_idx = t.ceil().long()
w = t.frac() w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp() return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment