Unverified Commit 3d2f60b3 authored by m957ymj75urz's avatar m957ymj75urz Committed by GitHub
Browse files

Merge branch 'master' into save-images

parents 6daf9bb2 54593db6
......@@ -477,9 +477,9 @@ class UNetModel(nn.Module):
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
# from omegaconf.listconfig import ListConfig
# if type(context_dim) == ListConfig:
# context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
......
......@@ -31,8 +31,25 @@ try:
except:
pass
if "--cpu" in sys.argv:
vram_state = CPU
if "--disable-xformers" in sys.argv:
XFORMERS_IS_AVAILBLE = False
else:
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
ENABLE_PYTORCH_ATTENTION = False
if "--use-pytorch-cross-attention" in sys.argv:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILBLE = False
if "--lowvram" in sys.argv:
set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
......@@ -54,6 +71,8 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
total_vram_available_mb = (total_vram - 1024) // 2
total_vram_available_mb = int(max(256, total_vram_available_mb))
if "--cpu" in sys.argv:
vram_state = CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state])
......@@ -159,6 +178,14 @@ def get_autocast_device(dev):
return dev.type
return "cuda"
def xformers_enabled():
if vram_state == CPU:
return False
return XFORMERS_IS_AVAILBLE
def pytorch_attention_enabled():
return ENABLE_PYTORCH_ATTENTION
def get_free_memory(dev=None, torch_free_too=False):
if dev is None:
dev = get_torch_device()
......
......@@ -6,7 +6,7 @@ import sd2_clip
import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf
import yaml
from .cldm import cldm
from .t2i_adapter import adapter
......@@ -726,12 +726,19 @@ def load_clip(ckpt_path, embedding_directory=None):
return clip
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
config = OmegaConf.load(config_path)
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
fp16 = False
if "unet_config" in model_config_params:
if "params" in model_config_params["unet_config"]:
if "use_fp16" in model_config_params["unet_config"]["params"]:
fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
clip = None
vae = None
......@@ -750,9 +757,13 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config.model)
model = instantiate_from_config(config["model"])
sd = load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return (ModelPatcher(model), clip, vae)
......@@ -853,4 +864,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
model = instantiate_from_config(model_config)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return (ModelPatcher(model), clip, vae)
import os
import sys
import shutil
import threading
import asyncio
......@@ -8,9 +9,6 @@ if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
import execution
import server
if __name__ == "__main__":
if '--help' in sys.argv:
print("Valid Command line Arguments:")
......@@ -18,6 +16,8 @@ if __name__ == "__main__":
print("\t--port 8188\t\t\tSet the listen port.")
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
print("\t--disable-xformers\t\tdisables xformers")
print()
print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
......@@ -31,6 +31,9 @@ if __name__ == "__main__":
print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16"
import execution
import server
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
......@@ -38,8 +41,8 @@ def prompt_worker(q, server):
e.execute(item[-2], item[-1])
q.task_done(item_id, e.outputs)
async def run(server, address='', port=8188, verbose=True):
await asyncio.gather(server.start(address, port, verbose), server.publish_loop())
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server):
from tqdm.auto import tqdm
......@@ -51,7 +54,14 @@ def hijack_progress(server):
return v
setattr(tqdm, "update", wrapped_func)
def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
cleanup_temp()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
......@@ -76,11 +86,22 @@ if __name__ == "__main__":
except:
pass
if '--quick-test-for-ci' in sys.argv:
exit(0)
call_on_start = None
if "--windows-standalone-build" in sys.argv:
def startup_server(address, port):
import webbrowser
webbrowser.open("http://{}:{}".format(address, port))
call_on_start = startup_server
if os.name == "nt":
try:
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
except KeyboardInterrupt:
pass
else:
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
cleanup_temp()
......@@ -775,6 +775,7 @@ class KSamplerAdvanced:
class SaveImage:
def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
self.url_suffix = ""
@classmethod
def INPUT_TYPES(s):
......@@ -818,6 +819,9 @@ class SaveImage:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
paths = list()
for image in images:
i = 255. * image.cpu().numpy()
......@@ -828,12 +832,25 @@ class SaveImage:
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, optimize=True)
paths.append(os.path.join(subfolder, file))
paths.append(os.path.join(subfolder, file + self.url_suffix))
counter += 1
return { "ui": { "images": paths } }
class PreviewImage(SaveImage):
def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
self.url_suffix = "?type=temp"
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod
......@@ -954,6 +971,7 @@ NODE_CLASS_MAPPINGS = {
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
"SaveImage": SaveImage,
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask,
"ImageScale": ImageScale,
......
......@@ -113,7 +113,7 @@ class PromptServer():
async def view_image(request):
if "file" in request.rel_url.query:
type = request.rel_url.query.get("type", "output")
if type != "output" and type != "input":
if type not in ["output", "input", "temp"]:
return web.Response(status=400)
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)
......@@ -267,7 +267,7 @@ class PromptServer():
msg = await self.messages.get()
await self.send(*msg)
async def start(self, address, port, verbose=True):
async def start(self, address, port, verbose=True, call_on_start=None):
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, address, port)
......@@ -278,3 +278,6 @@ class PromptServer():
if verbose:
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
if call_on_start is not None:
call_on_start(address, port)
......@@ -144,7 +144,14 @@ class ComfyApp {
if (numImages === 1 && !imageIndex) {
this.imageIndex = imageIndex = 0;
}
let shiftY = this.type === "SaveImage" ? 55 : this.imageOffset || 0;
let shiftY;
if (this.imageOffset != null) {
shiftY = this.imageOffset;
} else {
shiftY = this.computeSize()[1];
}
let dw = this.size[0];
let dh = this.size[1];
dh -= shiftY;
......@@ -400,6 +407,15 @@ class ComfyApp {
api.init();
}
#addKeyboardHandler() {
window.addEventListener("keydown", (e) => {
// Queue prompt using ctrl or command + enter
if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
this.queuePrompt(e.shiftKey ? -1 : 0);
}
});
}
/**
* Loads all extensions from the API into the window
*/
......@@ -466,6 +482,7 @@ class ComfyApp {
this.#addApiUpdateHandlers();
this.#addDropHandler();
this.#addPasteHandler();
this.#addKeyboardHandler();
await this.#invokeExtensionsAsync("setup");
}
......@@ -499,7 +516,11 @@ class ComfyApp {
if (Array.isArray(type)) {
// Enums e.g. latent rotation
this.addWidget("combo", inputName, type[0], () => {}, { values: type });
let defaultValue = type[0];
if (inputData[1] && inputData[1].default) {
defaultValue = inputData[1].default;
}
this.addWidget("combo", inputName, defaultValue, () => {}, { values: type });
} else if (`${type}:${inputName}` in widgets) {
// Support custom widgets by Type:Name
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment