Commit af291e6f authored by comfyanonymous's avatar comfyanonymous
Browse files

Convert line endings to unix.

parent cadef9ff
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
import comfy.utils import comfy.utils
class Blend: class Blend:
def __init__(self): def __init__(self):
pass pass
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"image1": ("IMAGE",), "image1": ("IMAGE",),
"image2": ("IMAGE",), "image2": ("IMAGE",),
"blend_factor": ("FLOAT", { "blend_factor": ("FLOAT", {
"default": 0.5, "default": 0.5,
"min": 0.0, "min": 0.0,
"max": 1.0, "max": 1.0,
"step": 0.01 "step": 0.01
}), }),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],), "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
}, },
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "blend_images" FUNCTION = "blend_images"
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
if image1.shape != image2.shape: if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2) image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
image2 = image2.permute(0, 2, 3, 1) image2 = image2.permute(0, 2, 3, 1)
blended_image = self.blend_mode(image1, image2, blend_mode) blended_image = self.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1) blended_image = torch.clamp(blended_image, 0, 1)
return (blended_image,) return (blended_image,)
def blend_mode(self, img1, img2, mode): def blend_mode(self, img1, img2, mode):
if mode == "normal": if mode == "normal":
return img2 return img2
elif mode == "multiply": elif mode == "multiply":
return img1 * img2 return img1 * img2
elif mode == "screen": elif mode == "screen":
return 1 - (1 - img1) * (1 - img2) return 1 - (1 - img1) * (1 - img2)
elif mode == "overlay": elif mode == "overlay":
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light": elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
else: else:
raise ValueError(f"Unsupported blend mode: {mode}") raise ValueError(f"Unsupported blend mode: {mode}")
def g(self, x): def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
class Blur: class Blur:
def __init__(self): def __init__(self):
pass pass
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"image": ("IMAGE",), "image": ("IMAGE",),
"blur_radius": ("INT", { "blur_radius": ("INT", {
"default": 1, "default": 1,
"min": 1, "min": 1,
"max": 31, "max": 31,
"step": 1 "step": 1
}), }),
"sigma": ("FLOAT", { "sigma": ("FLOAT", {
"default": 1.0, "default": 1.0,
"min": 0.1, "min": 0.1,
"max": 10.0, "max": 10.0,
"step": 0.1 "step": 0.1
}), }),
}, },
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "blur" FUNCTION = "blur"
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def gaussian_kernel(self, kernel_size: int, sigma: float): def gaussian_kernel(self, kernel_size: int, sigma: float):
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
d = torch.sqrt(x * x + y * y) d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum() return g / g.sum()
def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
if blur_radius == 0: if blur_radius == 0:
return (image,) return (image,)
batch_size, height, width, channels = image.shape batch_size, height, width, channels = image.shape
kernel_size = blur_radius * 2 + 1 kernel_size = blur_radius * 2 + 1
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
blurred = blurred.permute(0, 2, 3, 1) blurred = blurred.permute(0, 2, 3, 1)
return (blurred,) return (blurred,)
class Quantize: class Quantize:
def __init__(self): def __init__(self):
pass pass
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"image": ("IMAGE",), "image": ("IMAGE",),
"colors": ("INT", { "colors": ("INT", {
"default": 256, "default": 256,
"min": 1, "min": 1,
"max": 256, "max": 256,
"step": 1 "step": 1
}), }),
"dither": (["none", "floyd-steinberg"],), "dither": (["none", "floyd-steinberg"],),
}, },
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "quantize" FUNCTION = "quantize"
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
batch_size, height, width, _ = image.shape batch_size, height, width, _ = image.shape
result = torch.zeros_like(image) result = torch.zeros_like(image)
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
for b in range(batch_size): for b in range(batch_size):
tensor_image = image[b] tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy() img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB') pil_image = Image.fromarray(img, mode='RGB')
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option) quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array result[b] = quantized_array
return (result,) return (result,)
class Sharpen: class Sharpen:
def __init__(self): def __init__(self):
pass pass
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
"required": { "required": {
"image": ("IMAGE",), "image": ("IMAGE",),
"sharpen_radius": ("INT", { "sharpen_radius": ("INT", {
"default": 1, "default": 1,
"min": 1, "min": 1,
"max": 31, "max": 31,
"step": 1 "step": 1
}), }),
"alpha": ("FLOAT", { "alpha": ("FLOAT", {
"default": 1.0, "default": 1.0,
"min": 0.1, "min": 0.1,
"max": 5.0, "max": 5.0,
"step": 0.1 "step": 0.1
}), }),
}, },
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "sharpen" FUNCTION = "sharpen"
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float):
if sharpen_radius == 0: if sharpen_radius == 0:
return (image,) return (image,)
batch_size, height, width, channels = image.shape batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1 kernel_size = sharpen_radius * 2 + 1
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
center = kernel_size // 2 center = kernel_size // 2
kernel[center, center] = kernel_size**2 kernel[center, center] = kernel_size**2
kernel *= alpha kernel *= alpha
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
sharpened = sharpened.permute(0, 2, 3, 1) sharpened = sharpened.permute(0, 2, 3, 1)
result = torch.clamp(sharpened, 0, 1) result = torch.clamp(sharpened, 0, 1)
return (result,) return (result,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"Blend": Blend, "Blend": Blend,
"Blur": Blur, "Blur": Blur,
"Quantize": Quantize, "Quantize": Quantize,
"Sharpen": Sharpen, "Sharpen": Sharpen,
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment