"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "d0670f7c258878e92b32761de66a3997ef097ef8"
Commit f081017c authored by comfyanonymous

Save memory by storing text encoder weights in fp16 in most situations.

Do inference in fp32 to make sure quality stays exactly the same.
parent d7b3b0f8
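In short: the text encoder's parameters are persisted in fp16 to halve their memory footprint, while the actual computation is still carried out in fp32 so the outputs match the previous full-precision path. The hunks below implement this via model_management.should_use_fp16() and torch.autocast; the following is only a minimal standalone sketch of the same "store fp16, compute fp32" idea, with illustrative names (encoder_layer, forward_fp32) and sizes that are not taken from ComfyUI:

# Minimal sketch of the pattern, not the ComfyUI implementation.
import torch
import torch.nn.functional as F

encoder_layer = torch.nn.Linear(768, 768).half()  # parameters persisted in fp16

def forward_fp32(x):
    # Upcast weights and activations only for the duration of the matmul;
    # the stored parameters stay in fp16, so resident memory is halved.
    return F.linear(x.float(), encoder_layer.weight.float(), encoder_layer.bias.float())

tokens_embedded = torch.randn(1, 77, 768)  # fp32 activations
out = forward_fp32(tokens_embedded)        # result computed in fp32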
@@ -433,7 +433,7 @@ def text_encoder_device():
         return get_torch_device()
     elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
         #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
-        if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
+        if should_use_fp16() or torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
             return get_torch_device()
         else:
             return torch.device("cpu")
@@ -546,11 +546,8 @@ class CLIP:
         offload_device = model_management.text_encoder_offload_device()
         params['device'] = load_device
         self.cond_stage_model = clip(**(params))
-        #TODO: make sure this doesn't have a quality loss before enabling.
-        # if model_management.should_use_fp16(load_device):
-        #     self.cond_stage_model.half()
-        self.cond_stage_model = self.cond_stage_model.to()
+        if model_management.should_use_fp16(load_device):
+            self.cond_stage_model.half()
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
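The .half() call above converts every registered parameter and buffer of the loaded CLIP model to torch.float16 in place, which is where the memory saving comes from. A quick illustrative check (not ComfyUI code; the layer and sizes are arbitrary):

import torch

m = torch.nn.Linear(1024, 1024)
before = sum(p.numel() * p.element_size() for p in m.parameters())
m.half()
after = sum(p.numel() * p.element_size() for p in m.parameters())
print(before, after)  # the fp16 copy takes half the bytes of the fp32 original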
@@ -137,9 +137,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if backup_embeds.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
-            precision_scope = contextlib.nullcontext
+            precision_scope = lambda a, b: contextlib.nullcontext(a)
-        with precision_scope(model_management.get_autocast_device(device)):
+        with precision_scope(model_management.get_autocast_device(device), torch.float32):
             outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
         self.transformer.set_input_embeddings(backup_embeds)
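A note on the lambda above: contextlib.nullcontext accepts at most one argument, while the with statement now passes both a device and a dtype for the autocast branch. Wrapping nullcontext in a two-argument lambda lets both branches share the same call site. A tiny standalone sketch of that adapter (the "cpu"/None arguments are illustrative):

import contextlib

# Two-argument wrapper so both code paths can be invoked as scope(device, dtype);
# the dtype argument is simply ignored when no autocast is wanted.
precision_scope = lambda device, dtype: contextlib.nullcontext(device)

with precision_scope("cpu", None):
    pass  # runs with no precision change applied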
@@ -5,7 +5,6 @@ import { app } from "../../scripts/app.js";
 app.registerExtension({
     name: "Comfy.UploadImage",
     async beforeRegisterNodeDef(nodeType, nodeData, app) {
-        console.log(nodeData);
         if (nodeData?.input?.required?.image?.[1]?.image_upload === true) {
             nodeData.input.required.upload = ["IMAGEUPLOAD"];
         }