"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3cfe187dc70cd902d717ffa010c5f2c7a3956f0e"
Commit 97ee2306 authored by comfyanonymous

Make highvram and normalvram shift the text encoders to vram and back.

This is faster on big text encoder models than running it on the CPU.
parent fa1959e3
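
The change boils down to a shuttle pattern around the text encoder: keep it parked off the GPU, move it to VRAM only for the encode call, then move it back. A minimal, self-contained sketch of that pattern (illustrative names, not ComfyUI's actual API):

import torch

def encode_with_offload(text_encoder, tokens,
                        load_device=torch.device("cuda"),
                        offload_device=torch.device("cpu")):
    # Shift the encoder weights to VRAM only for the duration of the encode call.
    text_encoder.to(load_device)
    try:
        with torch.no_grad():
            out = text_encoder(tokens.to(load_device))
    finally:
        # Always shift back so VRAM stays free for the diffusion model.
        text_encoder.to(offload_device)
    return out
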
comfy/model_management.py
@@ -327,12 +327,18 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
-def text_encoder_device():
+def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")
 
+def text_encoder_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
@@ -422,10 +428,15 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
-def should_use_fp16():
+def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
+
+    if device is not None: #TODO
+        if hasattr(device, 'type'):
+            if (device.type == 'cpu' or device.type == 'mps'):
+                return False
+
     if FORCE_FP32:
         return False
...
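
For reference, a rough sketch of how a caller would combine the new helpers above; build_text_encoder() is a hypothetical stand-in and not part of this commit, the device/dtype decisions are the ones shown in the hunks:

from comfy import model_management

load_device = model_management.text_encoder_device()             # GPU in HIGH/NORMAL/SHARED VRAM states
offload_device = model_management.text_encoder_offload_device()  # GPU only with --gpu-only, else CPU

clip_model = build_text_encoder()                  # hypothetical constructor, for illustration only
if model_management.should_use_fp16(load_device):  # fp16 is refused for cpu/mps devices
    clip_model.half()
clip_model.to(offload_device)                      # parked off the GPU until encode time
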
comfy/sd.py
@@ -526,9 +526,10 @@ class CLIP:
             tokenizer = target.tokenizer
 
         self.device = model_management.text_encoder_device()
-        params["device"] = self.device
         self.cond_stage_model = clip(**(params))
-        self.cond_stage_model = self.cond_stage_model.to(self.device)
+        if model_management.should_use_fp16(self.device):
+            self.cond_stage_model.half()
+        self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device())
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
@@ -559,11 +560,14 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
+            self.cond_stage_model.to(self.device)
             self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
         except Exception as e:
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
             raise e
 
         cond_out = cond
...
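
The encode hunk above repeats the offload call on both the success and the exception path. A functionally equivalent sketch using try/finally (not the code in this commit) makes the guarantee explicit: the encoder always ends up back on the offload device, even if patching or encoding raises:

from comfy import model_management

def encode_with_cleanup(clip, tokens):
    # Mirrors the CLIP.encode flow shown above, with the duplicated
    # unpatch/offload calls folded into a single finally block.
    clip.cond_stage_model.to(clip.device)
    try:
        clip.patch_model()
        cond, pooled = clip.cond_stage_model.encode_token_weights(tokens)
    finally:
        clip.unpatch_model()
        clip.cond_stage_model.to(model_management.text_encoder_offload_device())
    return cond, pooled
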
comfy/sd1_clip.py
@@ -5,6 +5,8 @@ import comfy.ops
 import torch
 import traceback
 import zipfile
+from . import model_management
+import contextlib
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -46,7 +48,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             with modeling_utils.no_init_weights():
                 self.transformer = CLIPTextModel(config)
 
-        self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -95,7 +96,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]
 
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:
@@ -106,24 +107,34 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
+        device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
-
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
-        else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
-
-        pooled_output = outputs.pooler_output
-        if self.text_projection is not None:
-            pooled_output = pooled_output @ self.text_projection
-        return z, pooled_output
+        tokens = torch.LongTensor(tokens).to(device)
+
+        if backup_embeds.weight.dtype != torch.float32:
+            print("autocast clip")
+            precision_scope = torch.autocast
+        else:
+            precision_scope = contextlib.nullcontext
+            print("no autocast clip")
+
+        with precision_scope(model_management.get_autocast_device(device)):
+            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            self.transformer.set_input_embeddings(backup_embeds)
+
+            if self.layer == "last":
+                z = outputs.last_hidden_state
+            elif self.layer == "pooled":
+                z = outputs.pooler_output[:, None, :]
+            else:
+                z = outputs.hidden_states[self.layer_idx]
+                if self.layer_norm_hidden_state:
+                    z = self.transformer.text_model.final_layer_norm(z)
+
+            pooled_output = outputs.pooler_output
+            if self.text_projection is not None:
+                pooled_output = pooled_output @ self.text_projection
+        return z.float(), pooled_output.float()
 
     def encode(self, tokens):
         return self(tokens)
...
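
A self-contained sketch of the precision_scope selection used in forward() above: autocast is entered only when the stored weights are not fp32, otherwise contextlib.nullcontext leaves the computation in plain fp32 (the module and shapes below are illustrative):

import contextlib
import torch

def precision_scope_for(weight: torch.Tensor):
    if weight.dtype != torch.float32:
        # torch.autocast takes the device *type* string ("cuda", "cpu", ...),
        # which is what model_management.get_autocast_device() extracts from a device.
        return torch.autocast(weight.device.type)
    return contextlib.nullcontext()

linear = torch.nn.Linear(8, 8)        # fp32 weights -> nullcontext branch
with precision_scope_for(linear.weight):
    out = linear(torch.randn(1, 8))
print(out.dtype)                      # torch.float32

The .float() casts on the returned tensors ensure downstream code always receives fp32 regardless of which branch ran.
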