Commit 2b6b1781 authored by MoonRide303's avatar MoonRide303
Browse files

Added support for lanczos scaling

parent 6d3dee9d
import torch import torch
import torchvision
import math import math
import struct import struct
import comfy.checkpoint_pickle import comfy.checkpoint_pickle
import safetensors.torch import safetensors.torch
from PIL import Image
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
...@@ -346,6 +348,13 @@ def bislerp(samples, width, height): ...@@ -346,6 +348,13 @@ def bislerp(samples, width, height):
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result return result
def lanczos(samples, width, height):
images = [torchvision.transforms.functional.to_pil_image(image) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torchvision.transforms.functional.to_tensor(image) for image in images]
result = torch.stack(images)
return result
def common_upscale(samples, width, height, upscale_method, crop): def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center": if crop == "center":
old_width = samples.shape[3] old_width = samples.shape[3]
...@@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop): ...@@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop):
if upscale_method == "bislerp": if upscale_method == "bislerp":
return bislerp(s, width, height) return bislerp(s, width, height)
elif upscale_method == "lanczos":
return lanczos(s, width, height)
else: else:
return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
......
...@@ -211,7 +211,7 @@ class Sharpen: ...@@ -211,7 +211,7 @@ class Sharpen:
return (result,) return (result,)
class ImageScaleToTotalPixels: class ImageScaleToTotalPixels:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"] crop_methods = ["disabled", "center"]
@classmethod @classmethod
......
...@@ -1423,7 +1423,7 @@ class LoadImageMask: ...@@ -1423,7 +1423,7 @@ class LoadImageMask:
return True return True
class ImageScale: class ImageScale:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
crop_methods = ["disabled", "center"] crop_methods = ["disabled", "center"]
@classmethod @classmethod
...@@ -1444,7 +1444,7 @@ class ImageScale: ...@@ -1444,7 +1444,7 @@ class ImageScale:
return (s,) return (s,)
class ImageScaleBy: class ImageScaleBy:
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment