Commit 87b00b37 authored by comfyanonymous's avatar comfyanonymous
Browse files

Added an experimental VAEDecodeTiled.

This decodes the image with the VAE in tiles which should be faster and
use less vram.

It's in the _for_testing section so I might change/remove it or even
add the functionality to the regular VAEDecode node depending on how
well it performs which means don't depend too much on it.
parent 5796705c
......@@ -318,6 +318,37 @@ class VAE:
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples):
tile_x = tile_y = 64
overlap = 8
model_management.unload_model()
output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
self.first_stage_model = self.first_stage_model.to(self.device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
out = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
out_div = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * s_in.to(self.device))
pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
ps = pixel_samples.cpu()
mask = torch.ones_like(ps)
feather = overlap * 8
for t in range(feather):
mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
out[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += ps * mask
out_div[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += mask
output[b:b+1] = out/out_div
self.first_stage_model = self.first_stage_model.cpu()
return output.movedim(1,-1)
def encode(self, pixel_samples):
model_management.unload_model()
self.first_stage_model = self.first_stage_model.to(self.device)
......
......@@ -106,6 +106,21 @@ class VAEDecode:
def decode(self, vae, samples):
return (vae.decode(samples["samples"]), )
class VAEDecodeTiled:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "_for_testing"
def decode(self, vae, samples):
return (vae.decode_tiled(samples["samples"]), )
class VAEEncode:
def __init__(self, device="cpu"):
self.device = device
......@@ -789,6 +804,7 @@ NODE_CLASS_MAPPINGS = {
"ControlNetApply": ControlNetApply,
"ControlNetLoader": ControlNetLoader,
"DiffControlNetLoader": DiffControlNetLoader,
"VAEDecodeTiled": VAEDecodeTiled,
}
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment