"...composable_kernel_onnx.git" did not exist on "9d99a5807298c3f263d39a08328c3c68c930a900"
Commit 89a0767a authored by comfyanonymous

Smarter memory management.

Try to keep models in VRAM when possible.

Better lowvram mode for controlnets.
parent 2c97c302
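The commit replaces the single resident model (`current_loaded_model` plus the separate `current_gpu_controlnets` list) with one list of resident models, `current_loaded_models`, managed by `load_models_gpu()`. A minimal sketch of the new calling convention, assuming a ComfyUI checkout at this commit; `model_patcher` and `extra_patchers` are placeholder names for already-constructed ModelPatcher objects:

import comfy.model_management as mm

# Placeholder patchers: in ComfyUI these come from load_checkpoint_guess_config(),
# ControlNet.get_models(), load_gligen(), etc.
models = [model_patcher] + extra_patchers

# One call sizes the request (here for a 64x64 latent) and loads everything;
# the models then stay in mm.current_loaded_models until free_memory() or
# cleanup_models() evicts them.
mm.load_models_gpu(models, memory_required=mm.batch_area_memory(64 * 64))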
@@ -244,30 +244,15 @@ class Gligen(nn.Module):
         self.position_net = position_net
         self.key_dim = key_dim
         self.max_objs = 30
-        self.lowvram = False
+        self.current_device = torch.device("cpu")

     def _set_position(self, boxes, masks, positive_embeddings):
-        if self.lowvram == True:
-            self.position_net.to(boxes.device)
-
         objs = self.position_net(boxes, masks, positive_embeddings)
-
-        if self.lowvram == True:
-            self.position_net.cpu()
-            def func_lowvram(x, extra_options):
-                key = extra_options["transformer_index"]
-                module = self.module_list[key]
-                module.to(x.device)
-                r = module(x, objs)
-                module.cpu()
-                return r
-            return func_lowvram
-        else:
-            def func(x, extra_options):
-                key = extra_options["transformer_index"]
-                module = self.module_list[key]
-                return module(x, objs)
-            return func
+        def func(x, extra_options):
+            key = extra_options["transformer_index"]
+            module = self.module_list[key]
+            return module(x, objs)
+        return func

     def set_position(self, latent_image_shape, position_params, device):
         batch, c, h, w = latent_image_shape
@@ -312,14 +297,6 @@ class Gligen(nn.Module):
                            masks.to(device),
                            conds.to(device))

-    def set_lowvram(self, value=True):
-        self.lowvram = value
-
-    def cleanup(self):
-        self.lowvram = False
-
-    def get_models(self):
-        return [self]

 def load_gligen(sd):
     sd_k = sd.keys()
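With `set_lowvram()`, `cleanup()` and `get_models()` gone, the GLIGEN module no longer moves itself between devices; the hunks further down wrap the loaded module in a ModelPatcher (in comfy/sd.py) and the sampler reaches it through `.model`. A hedged sketch of the new access pattern, with all argument names as placeholders:

# load_gligen() now returns a ModelPatcher (see the sd.py hunk below), so the shared
# model management decides where the weights live, and the sampler calls:
gligen_patcher = comfy.sd.load_gligen(ckpt_path)                                   # ckpt_path: placeholder
patch = gligen_patcher.model.set_position(latent_shape, position_params, device)   # placeholders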
...
@@ -2,6 +2,7 @@ import psutil
 from enum import Enum
 from comfy.cli_args import args
 import torch
+import sys

 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
@@ -221,132 +222,161 @@ except:
     print("Could not pick default device.")

-current_loaded_model = None
-current_gpu_controlnets = []
-
-model_accelerated = False
-
-
-def unload_model():
-    global current_loaded_model
-    global model_accelerated
-    global current_gpu_controlnets
-    global vram_state
-
-    if current_loaded_model is not None:
-        if model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
-            model_accelerated = False
-
-        current_loaded_model.unpatch_model()
-        current_loaded_model.model.to(current_loaded_model.offload_device)
-        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
-        current_loaded_model = None
-        if vram_state != VRAMState.HIGH_VRAM:
-            soft_empty_cache()
-
-    if vram_state != VRAMState.HIGH_VRAM:
-        if len(current_gpu_controlnets) > 0:
-            for n in current_gpu_controlnets:
-                n.cpu()
-            current_gpu_controlnets = []
-
-def minimum_inference_memory():
-    return (768 * 1024 * 1024)
-
-def load_model_gpu(model):
-    global current_loaded_model
-    global vram_state
-    global model_accelerated
-
-    if model is current_loaded_model:
-        return
-    unload_model()
-
-    torch_dev = model.load_device
-    model.model_patches_to(torch_dev)
-    model.model_patches_to(model.model_dtype())
-    current_loaded_model = model
-
-    if is_device_cpu(torch_dev):
-        vram_set_state = VRAMState.DISABLED
-    else:
-        vram_set_state = vram_state
-
-    if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-        model_size = model.model_size()
-        current_free_mem = get_free_memory(torch_dev)
-        lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
-        if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
-            vram_set_state = VRAMState.LOW_VRAM
-
-    real_model = model.model
-    patch_model_to = None
-    if vram_set_state == VRAMState.DISABLED:
-        pass
-    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
-        model_accelerated = False
-        patch_model_to = torch_dev
-
-    try:
-        real_model = model.patch_model(device_to=patch_model_to)
-    except Exception as e:
-        model.unpatch_model()
-        unload_model()
-        raise e
-
-    if patch_model_to is not None:
-        real_model.to(torch_dev)
-
-    if vram_set_state == VRAMState.NO_VRAM:
-        device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
-        model_accelerated = True
-    elif vram_set_state == VRAMState.LOW_VRAM:
-        device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
-        model_accelerated = True
-
-    return current_loaded_model
-
-def load_controlnet_gpu(control_models):
-    global current_gpu_controlnets
-    global vram_state
-    if vram_state == VRAMState.DISABLED:
-        return
-
-    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
-        for m in control_models:
-            if hasattr(m, 'set_lowvram'):
-                m.set_lowvram(True)
-        #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
-        return
-
-    models = []
-    for m in control_models:
-        models += m.get_models()
-
-    for m in current_gpu_controlnets:
-        if m not in models:
-            m.cpu()
-
-    device = get_torch_device()
-    current_gpu_controlnets = []
-    for m in models:
-        current_gpu_controlnets.append(m.to(device))
-
-
-def load_if_low_vram(model):
-    global vram_state
-    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
-        return model.to(get_torch_device())
-    return model
-
-def unload_if_low_vram(model):
-    global vram_state
-    if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
-        return model.cpu()
-    return model
+current_loaded_models = []
+
+class LoadedModel:
+    def __init__(self, model):
+        self.model = model
+        self.model_accelerated = False
+        self.device = model.load_device
+
+    def model_memory(self):
+        return self.model.model_size()
+
+    def model_memory_required(self, device):
+        if device == self.model.current_device:
+            return 0
+        else:
+            return self.model_memory()
+
+    def model_load(self, lowvram_model_memory=0):
+        patch_model_to = None
+        if lowvram_model_memory == 0:
+            patch_model_to = self.device
+
+        self.model.model_patches_to(self.device)
+        self.model.model_patches_to(self.model.model_dtype())
+
+        try:
+            self.real_model = self.model.patch_model(device_to=patch_model_to) #TODO: do something with loras and offloading to CPU
+        except Exception as e:
+            self.model.unpatch_model(self.model.offload_device)
+            self.model_unload()
+            raise e
+
+        if lowvram_model_memory > 0:
+            print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
+            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
+            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+            self.model_accelerated = True
+
+        return self.real_model
+
+    def model_unload(self):
+        if self.model_accelerated:
+            accelerate.hooks.remove_hook_from_submodules(self.real_model)
+            self.model_accelerated = False
+
+        self.model.unpatch_model(self.model.offload_device)
+        self.model.model_patches_to(self.model.offload_device)
+
+    def __eq__(self, other):
+        return self.model is other.model
+
+def minimum_inference_memory():
+    return (1024 * 1024 * 1024)
+
+def unload_model_clones(model):
+    to_unload = []
+    for i in range(len(current_loaded_models)):
+        if model.is_clone(current_loaded_models[i].model):
+            to_unload = [i] + to_unload
+
+    for i in to_unload:
+        print("unload clone", i)
+        current_loaded_models.pop(i).model_unload()
+
+def free_memory(memory_required, device, keep_loaded=[]):
+    unloaded_model = False
+    for i in range(len(current_loaded_models) -1, -1, -1):
+        current_free_mem = get_free_memory(device)
+        if current_free_mem > memory_required:
+            break
+        shift_model = current_loaded_models[i]
+        if shift_model.device == device:
+            if shift_model not in keep_loaded:
+                current_loaded_models.pop(i).model_unload()
+                unloaded_model = True
+
+    if unloaded_model:
+        soft_empty_cache()
+
+def load_models_gpu(models, memory_required=0):
+    global vram_state
+
+    inference_memory = minimum_inference_memory()
+    extra_mem = max(inference_memory, memory_required)
+
+    models_to_load = []
+    models_already_loaded = []
+    for x in models:
+        loaded_model = LoadedModel(x)
+
+        if loaded_model in current_loaded_models:
+            index = current_loaded_models.index(loaded_model)
+            current_loaded_models.insert(0, current_loaded_models.pop(index))
+            models_already_loaded.append(loaded_model)
+        else:
+            models_to_load.append(loaded_model)
+
+    if len(models_to_load) == 0:
+        devs = set(map(lambda a: a.device, models_already_loaded))
+        for d in devs:
+            if d != torch.device("cpu"):
+                free_memory(extra_mem, d, models_already_loaded)
+        return
+
+    print("loading new")
+
+    total_memory_required = {}
+    for loaded_model in models_to_load:
+        unload_model_clones(loaded_model.model)
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+
+    for loaded_model in models_to_load:
+        model = loaded_model.model
+        torch_dev = model.load_device
+        if is_device_cpu(torch_dev):
+            vram_set_state = VRAMState.DISABLED
+        else:
+            vram_set_state = vram_state
+
+        lowvram_model_memory = 0
+        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+            model_size = loaded_model.model_memory_required(torch_dev)
+            current_free_mem = get_free_memory(torch_dev)
+            lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+            if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
+                vram_set_state = VRAMState.LOW_VRAM
+            else:
+                lowvram_model_memory = 0
+
+        if vram_set_state == VRAMState.NO_VRAM:
+            lowvram_model_memory = 256 * 1024 * 1024
+
+        cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+        current_loaded_models.insert(0, loaded_model)
+    return
+
+def load_model_gpu(model):
+    return load_models_gpu([model])
+
+def cleanup_models():
+    to_delete = []
+    for i in range(len(current_loaded_models)):
+        print(sys.getrefcount(current_loaded_models[i].model))
+        if sys.getrefcount(current_loaded_models[i].model) <= 2:
+            to_delete = [i] + to_delete
+
+    for i in to_delete:
+        x = current_loaded_models.pop(i)
+        x.model_unload()
+        del x

 def unet_offload_device():
     if vram_state == VRAMState.HIGH_VRAM:
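`current_loaded_models` is kept most-recently-used-first (`load_models_gpu()` inserts at index 0), and `free_memory()` walks it from the back, unloading models on the target device that the current call does not need. A standalone sketch of that eviction loop, for illustration only (not ComfyUI code):

def evict_until_free(loaded, memory_required, get_free_mem, keep=()):
    # Walk from least- to most-recently used and unload until enough memory is free.
    freed_any = False
    for i in range(len(loaded) - 1, -1, -1):
        if get_free_mem() > memory_required:
            break
        if loaded[i] not in keep:
            loaded.pop(i).model_unload()
            freed_any = True
    return freed_any

`model_memory_required()` returning 0 for a model already sitting on its target device is what lets repeated prompts with the same checkpoint skip the reload entirely.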
@@ -354,6 +384,21 @@ def unet_offload_device():
     else:
         return torch.device("cpu")

+def unet_inital_load_device(parameters, dtype):
+    torch_dev = get_torch_device()
+    if vram_state == VRAMState.HIGH_VRAM:
+        return torch_dev
+
+    cpu_dev = torch.device("cpu")
+    model_size = dtype.itemsize * parameters
+
+    mem_dev = get_free_memory(torch_dev)
+    mem_cpu = get_free_memory(cpu_dev)
+    if mem_dev > mem_cpu and model_size < mem_dev:
+        return torch_dev
+    else:
+        return cpu_dev
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
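`unet_inital_load_device()` sizes the raw weights as `dtype.itemsize * parameters` and only returns the GPU when that fits in free VRAM and the GPU has more free memory than the CPU. A worked example with an assumed parameter count (roughly SD1.5-UNet-sized; the number is illustrative, not from the commit):

import torch

parameters = 860_000_000                      # assumed parameter count, not from the commit
dtype = torch.float16
model_size = dtype.itemsize * parameters      # 2 bytes per parameter -> 1.72e9 bytes
print(model_size / (1024 ** 3), "GiB")        # ~1.6 GiB of weights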
@@ -456,6 +501,13 @@ def get_free_memory(dev=None, torch_free_too=False):
     else:
         return mem_free_total

+def batch_area_memory(area):
+    if xformers_enabled() or pytorch_attention_flash_attention():
+        #TODO: these formulas are copied from maximum_batch_area below
+        return (area / 20) * (1024 * 1024)
+    else:
+        return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
+
 def maximum_batch_area():
     global vram_state
     if vram_state == VRAMState.NO_VRAM:
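`batch_area_memory()` takes the latent area (sample.py passes `noise.shape[2] * noise.shape[3]`), so a 512×512 image corresponds to a 64×64 latent. A worked example of the two estimates, with the image size assumed for illustration:

area = 64 * 64                                                # latent h * w for a 512x512 image
flash_path = (area / 20) * (1024 * 1024)                      # xformers / flash attention path
fallback = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)      # all other attention backends
print(flash_path / 2**20, fallback / 2**20)                   # ~204.8 MiB vs ~3754.7 MiB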
...
@@ -51,19 +51,24 @@ def get_models_from_cond(cond, model_type):
             models += [c[1][model_type]]
     return models

-def load_additional_models(positive, negative, dtype):
+def get_additional_models(positive, negative):
     """loads additional models in positive and negative conditioning"""
     control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
+
+    control_models = []
+    for m in control_nets:
+        control_models += m.get_models()
+
     gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
-    gligen = [x[1].to(dtype) for x in gligen]
-    models = control_nets + gligen
-    comfy.model_management.load_controlnet_gpu(models)
+    gligen = [x[1] for x in gligen]
+    models = control_models + gligen
     return models

 def cleanup_additional_models(models):
     """cleanup additional models that were loaded"""
     for m in models:
-        m.cleanup()
+        if hasattr(m, 'cleanup'):
+            m.cleanup()

 def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
     device = comfy.model_management.get_torch_device()
@@ -72,7 +77,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
         noise_mask = prepare_mask(noise_mask, noise.shape, device)

     real_model = None
-    comfy.model_management.load_model_gpu(model)
+    models = get_additional_models(positive, negative)
+    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[2] * noise.shape[3]))
     real_model = model.model

     noise = noise.to(device)
@@ -81,7 +87,6 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     positive_copy = broadcast_cond(positive, noise.shape[0], device)
     negative_copy = broadcast_cond(negative, noise.shape[0], device)

-    models = load_additional_models(positive, negative, model.model_dtype())

     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
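Sampling now collects the ControlNet and GLIGEN patchers first and hands everything to `load_models_gpu()` in one call sized by `batch_area_memory()`; the later `load_additional_models()` call is removed. A condensed sketch of the resulting order inside `sample()` (arguments elided, not a drop-in snippet):

models = get_additional_models(positive, negative)     # ControlNet + GLIGEN ModelPatchers
comfy.model_management.load_models_gpu(
    [model] + models,
    comfy.model_management.batch_area_memory(noise.shape[2] * noise.shape[3]))
# ... run comfy.samplers.KSampler as before ...
cleanup_additional_models(models)                       # existing cleanup hook, unchanged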
...
@@ -88,9 +88,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                     gligen_type = gligen[0]
                     gligen_model = gligen[1]
                     if gligen_type == "position":
-                        gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
+                        gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
                     else:
-                        gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
+                        gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)

                     patches['middle_patch'] = [gligen_patch]
...
@@ -244,7 +244,7 @@ def set_attr(obj, attr, value):
     del prev

 class ModelPatcher:
-    def __init__(self, model, load_device, offload_device, size=0):
+    def __init__(self, model, load_device, offload_device, size=0, current_device=None):
         self.size = size
         self.model = model
         self.patches = {}
@@ -253,6 +253,10 @@ class ModelPatcher:
         self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
+        if current_device is None:
+            self.current_device = self.offload_device
+        else:
+            self.current_device = current_device

     def model_size(self):
         if self.size > 0:
@@ -267,7 +271,7 @@ class ModelPatcher:
         return size

     def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -276,6 +280,11 @@ class ModelPatcher:
         n.model_keys = self.model_keys
         return n

+    def is_clone(self, other):
+        if hasattr(other, 'model') and self.model is other.model:
+            return True
+        return False
+
     def set_model_sampler_cfg_function(self, sampler_cfg_function):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
             self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@@ -390,6 +399,11 @@ class ModelPatcher:
                 out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
                 set_attr(self.model, key, out_weight)
                 del temp_weight
+
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
+
         return self.model

     def calculate_weight(self, patches, weight, key):
@@ -482,7 +496,7 @@ class ModelPatcher:
         return weight

-    def unpatch_model(self):
+    def unpatch_model(self, device_to=None):
         keys = list(self.backup.keys())
         for k in keys:

@@ -490,6 +504,11 @@ class ModelPatcher:
         self.backup = {}

+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
+
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     key_map = model_lora_keys_unet(model.model)
     key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
@@ -630,11 +649,12 @@ class VAE:
         return samples

     def decode(self, samples_in):
-        model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         try:
+            memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.4
+            model_management.free_memory(memory_used, self.device)
             free_memory = model_management.get_free_memory(self.device)
-            batch_number = int((free_memory * 0.7) / (2562 * samples_in.shape[2] * samples_in.shape[3] * 64))
+            batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
@@ -650,19 +670,19 @@ class VAE:
         return pixel_samples

     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
-        model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
         self.first_stage_model = self.first_stage_model.to(self.offload_device)
         return output.movedim(1,-1)

     def encode(self, pixel_samples):
-        model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
+            memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.4 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+            model_management.free_memory(memory_used, self.device)
             free_memory = model_management.get_free_memory(self.device)
-            batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
+            batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):

@@ -677,7 +697,6 @@ class VAE:
         return samples

     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
-        model_management.unload_model()
         self.first_stage_model = self.first_stage_model.to(self.device)
         pixel_samples = pixel_samples.movedim(-1,1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
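The VAE no longer unloads the diffusion model wholesale; it estimates a per-sample cost, asks `free_memory()` to make that much room, and batches by whatever still fits. A worked decode example for a 64×64 latent (the free-VRAM figure is an assumption):

h, w = 64, 64                                        # latent size for a 512x512 output image
memory_used = (2562 * h * w * 64) * 1.4              # ~0.94 GB per sample (empirical constant)
free_vram = 8 * 1024**3                              # assume 8 GiB free after free_memory()
batch_number = max(1, int(free_vram / memory_used))
print(batch_number)                                  # 9 -> decode nine latents per forward pass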
@@ -757,6 +776,7 @@ class ControlNet(ControlBase):
     def __init__(self, control_model, global_average_pooling=False, device=None):
         super().__init__(device)
         self.control_model = control_model
+        self.control_model_wrapped = ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
         self.global_average_pooling = global_average_pooling

     def get_control(self, x_noisy, t, cond, batched_number):
@@ -786,11 +806,9 @@ class ControlNet(ControlBase):
             precision_scope = contextlib.nullcontext

         with precision_scope(model_management.get_autocast_device(self.device)):
-            self.control_model = model_management.load_if_low_vram(self.control_model)
             context = torch.cat(cond['c_crossattn'], 1)
             y = cond.get('c_adm', None)
             control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
-            self.control_model = model_management.unload_if_low_vram(self.control_model)
         out = {'middle':[], 'output': []}

         autocast_enabled = torch.is_autocast_enabled()
@@ -825,7 +843,7 @@ class ControlNet(ControlBase):
     def get_models(self):
         out = super().get_models()
-        out.append(self.control_model)
+        out.append(self.control_model_wrapped)
         return out
@@ -1004,7 +1022,6 @@ class T2IAdapter(ControlBase):
         self.copy_to(c)
         return c

-
 def load_t2i_adapter(t2i_data):
     keys = t2i_data.keys()
     if 'adapter' in keys:
@@ -1090,7 +1107,7 @@ def load_gligen(ckpt_path):
     model = gligen.load_gligen(data)
     if model_management.should_use_fp16():
         model = model.half()
-    return model
+    return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
     #TODO: this function is a mess and should be removed eventually
@@ -1202,8 +1219,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_clipvision:
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

+    dtype = torch.float32
+    if fp16:
+        dtype = torch.float16
+
+    inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
@@ -1224,7 +1246,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)

-    return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)
+    model_patcher = ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+    if inital_load_device != torch.device("cpu"):
+        print("loaded straight to GPU")
+        model_management.load_model_gpu(model_patcher)
+
+    return (model_patcher, clip, vae, clipvision)

 def load_unet(unet_path): #load unet in diffusers format
...
@@ -354,6 +354,7 @@ class PromptExecutor:
                 d = self.outputs_ui.pop(x)
                 del d

+        comfy.model_management.cleanup_models()
         if self.server.client_id is not None:
             self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
         executed = set()
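`cleanup_models()`, called here once cached outputs are dropped, relies on CPython reference counting: when the only remaining references to a model are the `current_loaded_models` entry and the temporary argument passed to `sys.getrefcount()`, the count is 2 and the model can be unloaded. A minimal, CPython-specific illustration:

import sys

cache = [object()]                  # stands in for current_loaded_models[i].model
extra_ref = cache[0]                # e.g. a cached node output still holding the model
print(sys.getrefcount(cache[0]))    # 3 -> referenced elsewhere, keep it loaded
del extra_ref
print(sys.getrefcount(cache[0]))    # 2 -> only the cache + getrefcount's argument, safe to unload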
...