"...git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "74297f5f9dad140ef209e6f89b38d88839bae6fe"
Commit 36a79531 authored by comfyanonymous

Greatly improve lowvram sampling speed by getting rid of accelerate.

Let me know if this breaks anything.
parent 261bcbb0
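In short: instead of letting accelerate hook every module and shuffle weights through the meta device, lowvram mode now moves whole modules onto the GPU until a memory budget is spent and marks the rest with `comfy_cast_weights = True`, so their weights are cast to the compute device on each forward (see the `comfy/ops.py` hunks below). That removes accelerate's per-layer dispatch from the sampling hot path and drops the dependency entirely.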
comfy/controlnet.py
@@ -283,7 +283,7 @@ class ControlLora(ControlNet):
         cm = self.control_model.state_dict()
 
         for k in sd:
-            weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
+            weight = sd[k]
             try:
                 comfy.utils.set_attr(self.control_model, k, weight)
             except:
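With accelerate gone, `sd[k]` can be copied straight into the control model because offloaded weights never sit on the meta device anymore. For context, a rough sketch of what the `comfy.utils.set_attr` helper does here (an approximation for illustration, not the helper's exact source):

```python
import torch

def set_attr(obj, attr, value):
    # Walk a dotted key like "input_blocks.0.0.weight" down to the owning
    # submodule, then swap in the new tensor as a frozen Parameter.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
    del prev
```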
comfy/model_base.py
@@ -162,11 +162,7 @@ class BaseModel(torch.nn.Module):
 
     def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
         clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_sd = self.diffusion_model.state_dict()
-        unet_state_dict = {}
-        for k in unet_sd:
-            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
-
+        unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
         vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
         if self.get_dtype() == torch.float16:
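The deleted loop existed only because accelerate parked offloaded weights on the meta device, where `state_dict()` returns tensors without storage; now the state dict is directly saveable. A small self-contained demonstration of that old failure mode (plain PyTorch, not ComfyUI code):

```python
import torch

# A module materialized on the "meta" device carries shape/dtype metadata
# but no actual data, so its state dict cannot be serialized as-is.
lin = torch.nn.Linear(4, 4, device="meta")
sd = lin.state_dict()
print(sd["weight"].is_meta)  # True: there is nothing to save
```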
comfy/model_management.py
@@ -218,15 +218,8 @@ if args.force_fp16:
     FORCE_FP16 = True
 
 if lowvram_available:
-    try:
-        import accelerate
-        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
-            vram_state = set_vram_to
-    except Exception as e:
-        import traceback
-        print(traceback.format_exc())
-        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
-        lowvram_available = False
+    if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+        vram_state = set_vram_to
 
 if cpu_state != CPUState.GPU:
@@ -298,8 +291,20 @@ class LoadedModel:
 
         if lowvram_model_memory > 0:
             print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
-            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
-            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+            mem_counter = 0
+            for m in self.real_model.modules():
+                if hasattr(m, "comfy_cast_weights"):
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    module_mem = 0
+                    sd = m.state_dict()
+                    for k in sd:
+                        t = sd[k]
+                        module_mem += t.nelement() * t.element_size()
+                    if mem_counter + module_mem < lowvram_model_memory:
+                        m.to(self.device)
+                        mem_counter += module_mem
+
             self.model_accelerated = True
 
         if is_intel_xpu() and not args.disable_ipex_optimize:
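The new loading path sizes each module from its state dict (`nelement() * element_size()` bytes per tensor) and moves modules onto the GPU greedily until the budget runs out; whatever stays behind relies on the cast-on-forward path instead. A standalone sketch of the same accounting, with hypothetical budget and model values:

```python
import torch

def module_size(m: torch.nn.Module) -> int:
    # Bytes held by a module's parameters and buffers, mirroring the
    # t.nelement() * t.element_size() accounting in the hunk above.
    return sum(t.nelement() * t.element_size() for t in m.state_dict().values())

# Hypothetical budget and model; the real code uses lowvram_model_memory,
# self.real_model, and guards on hasattr(m, "comfy_cast_weights").
budget = 256 * 1024 * 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(*[torch.nn.Linear(4096, 4096) for _ in range(8)])

used = 0
for m in model.modules():
    if len(list(m.children())) > 0:
        continue  # only leaf modules, so containers aren't double counted
    size = module_size(m)
    if used + size < budget:
        m.to(device)  # stays resident on the device for the whole run
        used += size
    # modules left behind would cast their weights per-forward instead
print(f"resident: {used / (1024**2):.1f} MiB of {budget // (1024**2)} MiB budget")
```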
@@ -309,7 +314,11 @@ class LoadedModel:
 
     def model_unload(self):
         if self.model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(self.real_model)
+            for m in self.real_model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+
             self.model_accelerated = False
 
         self.model.unpatch_model(self.model.offload_device)
@@ -402,14 +411,14 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
             if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
                 vram_set_state = VRAMState.LOW_VRAM
             else:
                 lowvram_model_memory = 0
 
         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 256 * 1024 * 1024
+            lowvram_model_memory = 64 * 1024 * 1024
 
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
         current_loaded_models.insert(0, loaded_model)
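The budget formula keeps everything except roughly 1 GiB of headroom, divided by 1.3 for slack; the floor drops from 256 MiB to 64 MiB, presumably viable now that partially loaded models no longer carry accelerate's per-layer hooks. A worked example assuming a hypothetical 4 GiB of free VRAM:

```python
# Hypothetical: 4 GiB currently free on the device.
current_free_mem = 4 * 1024 * 1024 * 1024
lowvram_model_memory = int(max(64 * (1024 * 1024),
                               (current_free_mem - 1024 * (1024 * 1024)) / 1.3))
print(lowvram_model_memory / (1024 * 1024))  # ~2363 MiB of weights stay on the GPU
```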
@@ -566,6 +575,11 @@ def supports_dtype(device, dtype): #TODO
         return True
     return False
 
+def device_supports_non_blocking(device):
+    if is_device_mps(device):
+        return False #pytorch bug? mps doesn't support non blocking
+    return True
+
 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False
     if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@@ -576,9 +590,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
     elif is_intel_xpu():
         device_supports_cast = True
 
-    non_blocking = True
-    if is_device_mps(device):
-        non_blocking = False #pytorch bug? mps doesn't support non blocking
+    non_blocking = device_supports_non_blocking(device)
 
     if device_supports_cast:
         if copy:
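Factoring the MPS check into `device_supports_non_blocking` lets other call sites reuse it (the new `comfy/ops.py` below is one). As background: `non_blocking=True` lets host-to-device copies return immediately and overlap with compute on CUDA (truly asynchronous when the source is pinned), but it is disabled on MPS per the comment above. A minimal caller, assuming the repo's import path:

```python
import torch
import comfy.model_management

def to_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Ask model_management whether non-blocking copies are safe here
    # (returns False on MPS, True elsewhere).
    non_blocking = comfy.model_management.device_supports_non_blocking(device)
    return tensor.to(device, non_blocking=non_blocking)
```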
@@ -742,11 +754,7 @@ def soft_empty_cache(force=False):
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
-def resolve_lowvram_weight(weight, model, key):
-    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-        key_split = key.split('.')              # I have no idea why they don't just leave the weight there instead of using the meta device.
-        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
-        weight = op._hf_hook.weights_map[key_split[-1]]
+def resolve_lowvram_weight(weight, model, key): #TODO: remove
     return weight
 
 #TODO: might be cleaner to put this somewhere else
comfy/ops.py
 import torch
 from contextlib import contextmanager
+import comfy.model_management
+
+def cast_bias_weight(s, input):
+    bias = None
+    non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    return weight, bias
 
 class disable_weight_init:
     class Linear(torch.nn.Linear):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv2d(torch.nn.Conv2d):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class Conv3d(torch.nn.Conv3d):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class GroupNorm(torch.nn.GroupNorm):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     class LayerNorm(torch.nn.LayerNorm):
+        comfy_cast_weights = False
         def reset_parameters(self):
             return None
 
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
@@ -31,35 +97,19 @@ class disable_weight_init:
         else:
             raise ValueError(f"unsupported dimensions: {dims}")
 
-def cast_bias_weight(s, input):
-    bias = None
-    if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype)
-    weight = s.weight.to(device=input.device, dtype=input.dtype)
-    return weight, bias
-
 class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+        comfy_cast_weights = True
 
     class Conv2d(disable_weight_init.Conv2d):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+        comfy_cast_weights = True
 
     class Conv3d(disable_weight_init.Conv3d):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+        comfy_cast_weights = True
 
     class GroupNorm(disable_weight_init.GroupNorm):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+        comfy_cast_weights = True
 
     class LayerNorm(disable_weight_init.LayerNorm):
-        def forward(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+        comfy_cast_weights = True
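With the forward logic folded into `disable_weight_init`, the `manual_cast` subclasses reduce to flipping a class attribute, and the lowvram loader can enable the same behavior per-instance by setting `comfy_cast_weights = True` on existing modules (shadowing the class attribute). A quick sketch of the cast-on-forward behavior:

```python
import torch
import comfy.ops

# Weights stored in fp16; input arrives in fp32. manual_cast.Linear casts
# weight and bias to the input's device and dtype on every forward call.
layer = comfy.ops.manual_cast.Linear(8, 8).half()
x = torch.randn(1, 8)  # float32 input
out = layer(x)
print(out.dtype)  # torch.float32: parameters were cast to match the input
```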
requirements.txt
@@ -4,7 +4,6 @@ einops
 transformers>=4.25.1
 safetensors>=0.3.0
 aiohttp
-accelerate
 pyyaml
 Pillow
 scipy