"vscode:/vscode.git/clone" did not exist on "0e82fd3df4b2f36e3352b72ec2a441d9efed3a5f"
Commit 36a79531 authored by comfyanonymous's avatar comfyanonymous
Browse files

Greatly improve lowvram sampling speed by getting rid of accelerate.

Let me know if this breaks anything.
parent 261bcbb0
......@@ -283,7 +283,7 @@ class ControlLora(ControlNet):
cm = self.control_model.state_dict()
for k in sd:
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
weight = sd[k]
try:
comfy.utils.set_attr(self.control_model, k, weight)
except:
......
......@@ -162,11 +162,7 @@ class BaseModel(torch.nn.Module):
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_sd = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
......
......@@ -218,15 +218,8 @@ if args.force_fp16:
FORCE_FP16 = True
if lowvram_available:
try:
import accelerate
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
if cpu_state != CPUState.GPU:
......@@ -298,8 +291,20 @@ class LoadedModel:
if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
mem_counter = 0
for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = 0
sd = m.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem
self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize:
......@@ -309,7 +314,11 @@ class LoadedModel:
def model_unload(self):
if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model)
for m in self.real_model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device)
......@@ -402,14 +411,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
else:
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024
lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
......@@ -566,6 +575,11 @@ def supports_dtype(device, dtype): #TODO
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return True
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
......@@ -576,9 +590,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True
non_blocking = True
if is_device_mps(device):
non_blocking = False #pytorch bug? mps doesn't support non blocking
non_blocking = device_supports_non_blocking(device)
if device_supports_cast:
if copy:
......@@ -742,11 +754,7 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key):
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight
#TODO: might be cleaner to put this somewhere else
......
import torch
from contextlib import contextmanager
import comfy.model_management
def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
return weight, bias
class disable_weight_init:
class Linear(torch.nn.Linear):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv2d(torch.nn.Conv2d):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
......@@ -31,35 +97,19 @@ class disable_weight_init:
else:
raise ValueError(f"unsupported dimensions: {dims}")
def cast_bias_weight(s, input):
bias = None
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype)
weight = s.weight.to(device=input.device, dtype=input.dtype)
return weight, bias
class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
comfy_cast_weights = True
class Conv2d(disable_weight_init.Conv2d):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
comfy_cast_weights = True
class Conv3d(disable_weight_init.Conv3d):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
comfy_cast_weights = True
class LayerNorm(disable_weight_init.LayerNorm):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
comfy_cast_weights = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment