Commit 87b00b37 authored by comfyanonymous

Added an experimental VAEDecodeTiled.

This decodes the image with the VAE in tiles, which should be faster and
use less VRAM.

It's in the _for_testing section, so I might change or remove it, or add
the functionality to the regular VAEDecode node, depending on how well it
performs. In other words: don't depend on it too much.
parent 5796705c
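
The tiled decoder below blends overlapping tiles with a linear feather mask and then renormalizes each pixel by the accumulated mask weights (out / out_div), so seams between tiles average out smoothly. Here is a minimal standalone sketch of that mask, mirroring the feathering loop in the diff; the function name and the toy sizes are illustrative, not part of the commit:

import torch

def feather_mask(height, width, feather):
    # Weight ramps linearly from 1/feather at each edge up to 1.0 in the interior.
    mask = torch.ones(1, 3, height, width)
    for t in range(feather):
        w = (t + 1) / feather
        mask[:, :, t, :] *= w                    # top edge
        mask[:, :, height - 1 - t, :] *= w       # bottom edge
        mask[:, :, :, t] *= w                    # left edge
        mask[:, :, :, width - 1 - t] *= w        # right edge
    return mask

m = feather_mask(8, 8, feather=3)
print(m[0, 0])  # corners ~= row_weight * col_weight, center == 1.0

Because out_div accumulates the same mask, pixels near the image border (which are only feathered on one side) are still normalized back to full strength.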
comfy/sd.py
@@ -318,6 +318,37 @@ class VAE:
        pixel_samples = pixel_samples.cpu().movedim(1,-1)
        return pixel_samples
    def decode_tiled(self, samples):
        tile_x = tile_y = 64
        overlap = 8
        model_management.unload_model()
        output = torch.empty((samples.shape[0], 3, samples.shape[2] * 8, samples.shape[3] * 8), device="cpu")
        self.first_stage_model = self.first_stage_model.to(self.device)
        for b in range(samples.shape[0]):
            s = samples[b:b+1]
            out = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
            out_div = torch.zeros((s.shape[0], 3, s.shape[2] * 8, s.shape[3] * 8), device="cpu")
            for y in range(0, s.shape[2], tile_y - overlap):
                for x in range(0, s.shape[3], tile_x - overlap):
                    s_in = s[:,:,y:y+tile_y,x:x+tile_x]
                    pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * s_in.to(self.device))
                    pixel_samples = torch.clamp((pixel_samples + 1.0) / 2.0, min=0.0, max=1.0)
                    ps = pixel_samples.cpu()
                    mask = torch.ones_like(ps)
                    feather = overlap * 8
                    for t in range(feather):
                        mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
                        mask[:,:,mask.shape[2] - 1 - t:mask.shape[2] - t,:] *= ((1.0/feather) * (t + 1))
                        mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
                        mask[:,:,:,mask.shape[3] - 1 - t:mask.shape[3] - t] *= ((1.0/feather) * (t + 1))
                    out[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += ps * mask
                    out_div[:,:,y*8:(y+tile_y)*8,x*8:(x+tile_x)*8] += mask
            output[b:b+1] = out / out_div
        self.first_stage_model = self.first_stage_model.cpu()
        return output.movedim(1,-1)
    def encode(self, pixel_samples):
        model_management.unload_model()
        self.first_stage_model = self.first_stage_model.to(self.device)
...
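
For reference, a hypothetical direct call to the new method; `vae` is assumed to be an already-loaded comfy.sd.VAE instance, and a random tensor stands in for a real sampler latent:

import torch

latent = torch.randn(1, 4, 64, 96)   # [B, C, H/8, W/8]; stand-in for a sampler's output latent
image = vae.decode_tiled(latent)     # `vae` assumed loaded; result lands on the CPU, channels last
print(image.shape)                   # torch.Size([1, 512, 768, 3])

With tile_x = tile_y = 64 and overlap = 8, this 64x96 latent is decoded in four overlapping passes through the VAE decoder instead of one large pass, which keeps the peak activation size per decode much smaller.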
nodes.py
@@ -106,6 +106,21 @@ class VAEDecode:
    def decode(self, vae, samples):
        return (vae.decode(samples["samples"]), )
class VAEDecodeTiled:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    CATEGORY = "_for_testing"

    def decode(self, vae, samples):
        return (vae.decode_tiled(samples["samples"]), )
class VAEEncode:
    def __init__(self, device="cpu"):
        self.device = device
@@ -789,6 +804,7 @@ NODE_CLASS_MAPPINGS = {
"ControlNetApply": ControlNetApply, "ControlNetApply": ControlNetApply,
"ControlNetLoader": ControlNetLoader, "ControlNetLoader": ControlNetLoader,
"DiffControlNetLoader": DiffControlNetLoader, "DiffControlNetLoader": DiffControlNetLoader,
"VAEDecodeTiled": VAEDecodeTiled,
} }
CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
...
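
Finally, a hypothetical call to the node class itself outside the UI; `vae` and `latent_dict` are assumed to come from the usual loader and sampler nodes and are not defined in the diff:

# `vae`: a loaded comfy.sd.VAE; `latent_dict`: the {"samples": tensor} dict a sampler node returns.
node = VAEDecodeTiled()
(image,) = node.decode(vae=vae, samples=latent_dict)  # forwards to vae.decode_tiled()
print(image.shape)  # [B, H, W, 3], the same output format as the regular VAEDecode node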