Commit d9d8702d authored by comfyanonymous's avatar comfyanonymous
Browse files

percent_to_sigma now returns a float instead of a tensor.

parent 8a451234
...@@ -77,9 +77,9 @@ class ModelSamplingDiscrete(torch.nn.Module): ...@@ -77,9 +77,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0:
return torch.tensor(999999999.9) return 999999999.9
if percent >= 1.0: if percent >= 1.0:
return torch.tensor(0.0) return 0.0
percent = 1.0 - percent percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)) return self.sigma(torch.tensor(percent * 999.0)).item()
...@@ -67,11 +67,11 @@ class ModelSamplingDiscreteLCM(torch.nn.Module): ...@@ -67,11 +67,11 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
def percent_to_sigma(self, percent): def percent_to_sigma(self, percent):
if percent <= 0.0: if percent <= 0.0:
return torch.tensor(999999999.9) return 999999999.9
if percent >= 1.0: if percent >= 1.0:
return torch.tensor(0.0) return 0.0
percent = 1.0 - percent percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)) return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas): def rescale_zero_terminal_snr_sigmas(sigmas):
......
...@@ -16,8 +16,8 @@ class PatchModelAddDownscale: ...@@ -16,8 +16,8 @@ class PatchModelAddDownscale:
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item() sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item() sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
def input_block_patch(h, transformer_options): def input_block_patch(h, transformer_options):
if transformer_options["block"][1] == block_number: if transformer_options["block"][1] == block_number:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment