Commit 97ee2306 authored by comfyanonymous

Make highvram and normalvram shift the text encoders to vram and back.

This is faster for big text encoder models than running them on the CPU.
parent fa1959e3
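Conceptually, the change is a load/compute/offload cycle around each text-encoder call: the weights rest on an offload device (normally the CPU) and are shifted to the GPU only for the duration of an encode. Below is a minimal sketch of that pattern, not the repository code; it assumes a generic torch.nn.Module text encoder, and the names encode_on_gpu, compute_device, and offload_device are illustrative only.

    import torch

    def encode_on_gpu(text_encoder, tokens,
                      compute_device=torch.device("cuda"),
                      offload_device=torch.device("cpu")):
        # Hypothetical helper: shift the weights to VRAM only for the heavy forward pass.
        text_encoder.to(compute_device)
        try:
            with torch.no_grad():
                out = text_encoder(tokens.to(compute_device))
        finally:
            # Always shift the weights back, even if encoding raises.
            text_encoder.to(offload_device)
        return out

The diff applies the same idea: text_encoder_device() picks the compute device for the HIGH_VRAM, SHARED, and NORMAL_VRAM states, text_encoder_offload_device() picks where the weights rest between calls, and CLIP.encode moves the model back to the offload device on both the success and failure paths.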
@@ -327,12 +327,18 @@ def unload_if_low_vram(model):
        return model.cpu()
    return model

def text_encoder_device():
def text_encoder_offload_device():
    if args.gpu_only:
        return get_torch_device()
    else:
        return torch.device("cpu")

def text_encoder_device():
    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM:
        return get_torch_device()
    else:
        return torch.device("cpu")

def get_autocast_device(dev):
    if hasattr(dev, 'type'):
        return dev.type
@@ -422,10 +428,15 @@ def mps_mode():
    global cpu_state
    return cpu_state == CPUState.MPS

def should_use_fp16():
def should_use_fp16(device=None):
    global xpu_available
    global directml_enabled

    if device is not None: #TODO
        if hasattr(device, 'type'):
            if (device.type == 'cpu' or device.type == 'mps'):
                return False

    if FORCE_FP32:
        return False
......
@@ -526,9 +526,10 @@ class CLIP:
        tokenizer = target.tokenizer

        self.device = model_management.text_encoder_device()
        params["device"] = self.device
        self.cond_stage_model = clip(**(params))
        self.cond_stage_model = self.cond_stage_model.to(self.device)
        if model_management.should_use_fp16(self.device):
            self.cond_stage_model.half()
        self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device())

        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
        self.patcher = ModelPatcher(self.cond_stage_model)
@@ -559,11 +560,14 @@ class CLIP:
        if self.layer_idx is not None:
            self.cond_stage_model.clip_layer(self.layer_idx)
        try:
            self.cond_stage_model.to(self.device)
            self.patch_model()
            cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
            self.unpatch_model()
            self.cond_stage_model.to(model_management.text_encoder_offload_device())
        except Exception as e:
            self.unpatch_model()
            self.cond_stage_model.to(model_management.text_encoder_offload_device())
            raise e

        cond_out = cond
......
@@ -5,6 +5,8 @@ import comfy.ops
import torch
import traceback
import zipfile
from . import model_management
import contextlib

class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
@@ -46,7 +48,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        with modeling_utils.no_init_weights():
            self.transformer = CLIPTextModel(config)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
@@ -95,7 +96,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
            out_tokens += [tokens_temp]

        if len(embedding_weights) > 0:
            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
            new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
            n = token_dict_size
            for x in embedding_weights:
@@ -106,8 +107,18 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
        tokens = torch.LongTensor(tokens).to(self.device)
        tokens = torch.LongTensor(tokens).to(device)

        if backup_embeds.weight.dtype != torch.float32:
            print("autocast clip")
            precision_scope = torch.autocast
        else:
            precision_scope = contextlib.nullcontext
            print("no autocast clip")

        with precision_scope(model_management.get_autocast_device(device)):
            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        self.transformer.set_input_embeddings(backup_embeds)
@@ -123,7 +134,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        pooled_output = outputs.pooler_output
        if self.text_projection is not None:
            pooled_output = pooled_output @ self.text_projection
        return z, pooled_output
        return z.float(), pooled_output.float()

    def encode(self, tokens):
        return self(tokens)
......