Commit ba4a754a authored by pythongosssss

Merge remote-tracking branch 'origin/master' into a1111-meta-v2

parents 85989c74 bf1dc1d9
......@@ -9,13 +9,9 @@ from typing import Optional, Any
from ldm.modules.attention import MemoryEfficientCrossAttention
import model_management
try:
if model_management.xformers_enabled():
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
print("No module 'xformers'. Proceeding without it.")
try:
OOM_EXCEPTION = torch.cuda.OutOfMemoryError
......@@ -303,6 +299,64 @@ class MemoryEfficientAttnBlock(nn.Module):
out = self.proj_out(out)
return x+out
class MemoryEfficientAttnBlockPytorch(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)
        self.attention_op: Optional[Any] = None

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        B, C, H, W = q.shape
        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(B, t.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, t.shape[1], C)
            .contiguous(),
            (q, k, v),
        )
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

        out = (
            out.unsqueeze(0)
            .reshape(B, 1, out.shape[1], C)
            .permute(0, 2, 1, 3)
            .reshape(B, out.shape[1], C)
        )
        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
        out = self.proj_out(out)
        return x+out
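The new block above keeps the same 1x1 conv q/k/v projections but replaces the explicit softmax(QKᵀ/√d)·V computation with torch.nn.functional.scaled_dot_product_attention, which PyTorch 2.0 dispatches to flash or memory-efficient kernels when the hardware and dtype allow. A minimal sketch of that equivalence, with illustrative shapes (not the block's actual tensors):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: a feature map flattened to (B, H*W, C) tokens, as in the block above.
B, N, C = 2, 64, 32
q = torch.randn(B, N, C)
k = torch.randn(B, N, C)
v = torch.randn(B, N, C)

# Explicit attention, the way the non-fused blocks compute it.
scale = C ** -0.5
weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
out_manual = weights @ v

# Fused PyTorch 2.0 path used by MemoryEfficientAttnBlockPytorch.
out_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

print(torch.allclose(out_manual, out_fused, atol=1e-5))  # expected: True, up to numerical precision
```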
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
......@@ -315,8 +369,10 @@ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
if model_management.xformers_enabled() and attn_type == "vanilla":
attn_type = "vanilla-xformers"
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
attn_type = "vanilla-pytorch"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
......@@ -324,6 +380,8 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif attn_type == "vanilla-pytorch":
return MemoryEfficientAttnBlockPytorch(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
......
......@@ -477,9 +477,9 @@ class UNetModel(nn.Module):
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
# from omegaconf.listconfig import ListConfig
# if type(context_dim) == ListConfig:
# context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
......
......@@ -31,8 +31,25 @@ try:
except:
pass
if "--cpu" in sys.argv:
vram_state = CPU
if "--disable-xformers" in sys.argv:
XFORMERS_IS_AVAILBLE = False
else:
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
ENABLE_PYTORCH_ATTENTION = False
if "--use-pytorch-cross-attention" in sys.argv:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILBLE = False
if "--lowvram" in sys.argv:
set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
......@@ -54,6 +71,8 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
total_vram_available_mb = (total_vram - 1024) // 2
total_vram_available_mb = int(max(256, total_vram_available_mb))
if "--cpu" in sys.argv:
vram_state = CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM"][vram_state])
......@@ -159,6 +178,14 @@ def get_autocast_device(dev):
return dev.type
return "cuda"
def xformers_enabled():
if vram_state == CPU:
return False
return XFORMERS_IS_AVAILBLE
def pytorch_attention_enabled():
return ENABLE_PYTORCH_ATTENTION
def get_free_memory(dev=None, torch_free_too=False):
if dev is None:
dev = get_torch_device()
......
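The model_management changes add a --use-pytorch-cross-attention flag that switches all three scaled-dot-product backends on and disables xformers, plus a --disable-xformers flag; attention code now asks model_management.xformers_enabled() / pytorch_attention_enabled() instead of importing xformers itself. A condensed sketch of that selection logic, omitting the CPU/VRAM-state handling and assuming the same command-line flags:

```python
import sys
import torch

ENABLE_PYTORCH_ATTENTION = "--use-pytorch-cross-attention" in sys.argv

if ENABLE_PYTORCH_ATTENTION:
    # Let PyTorch choose among the math / flash / memory-efficient SDP kernels.
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    XFORMERS_IS_AVAILBLE = False
elif "--disable-xformers" in sys.argv:
    XFORMERS_IS_AVAILBLE = False
else:
    try:
        import xformers.ops  # noqa: F401
        XFORMERS_IS_AVAILBLE = True
    except ImportError:
        XFORMERS_IS_AVAILBLE = False

def pytorch_attention_enabled():
    return ENABLE_PYTORCH_ATTENTION

def xformers_enabled():
    return XFORMERS_IS_AVAILBLE
```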
......@@ -6,7 +6,7 @@ import sd2_clip
import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
from omegaconf import OmegaConf
import yaml
from .cldm import cldm
from .t2i_adapter import adapter
......@@ -726,12 +726,19 @@ def load_clip(ckpt_path, embedding_directory=None):
return clip
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
config = OmegaConf.load(config_path)
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
scale_factor = model_config_params['scale_factor']
vae_config = model_config_params['first_stage_config']
fp16 = False
if "unet_config" in model_config_params:
if "params" in model_config_params["unet_config"]:
if "use_fp16" in model_config_params["unet_config"]["params"]:
fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
clip = None
vae = None
......@@ -750,9 +757,13 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
w.cond_stage_model = clip.cond_stage_model
load_state_dict_to = [w]
model = instantiate_from_config(config.model)
model = instantiate_from_config(config["model"])
sd = load_torch_file(ckpt_path)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return (ModelPatcher(model), clip, vae)
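load_checkpoint now reads the model config with plain yaml.safe_load instead of OmegaConf, pulls use_fp16 out of the UNet params, and calls .half() on the instantiated model when it is set. A minimal sketch of that config handling; the v1-inference.yaml path in the usage comment is only an example:

```python
import yaml

def read_unet_fp16(config_path):
    # Plain YAML is enough here; OmegaConf interpolation is not needed for these keys.
    with open(config_path, "r") as stream:
        config = yaml.safe_load(stream)

    params = config["model"]["params"]
    unet_params = params.get("unet_config", {}).get("params", {})
    return bool(unet_params.get("use_fp16", False))

# Usage sketch (the path is an assumption):
# if read_unet_fp16("v1-inference.yaml"):
#     model = model.half()
```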
......@@ -853,4 +864,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, e
model = instantiate_from_config(model_config)
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
if fp16:
model = model.half()
return (ModelPatcher(model), clip, vae)
import os
from comfy_extras.chainner_models import model_loading
from comfy.sd import load_torch_file
import comfy.model_management
import model_management
from nodes import filter_files_extensions, recursive_search, supported_ckpt_extensions
import torch
import comfy.utils
......@@ -38,7 +38,7 @@ class ImageUpscaleWithModel:
CATEGORY = "image/upscaling"
def upscale(self, upscale_model, image):
device = comfy.model_management.get_torch_device()
device = model_management.get_torch_device()
upscale_model.to(device)
in_img = image.movedim(-1,-3).to(device)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale)
......
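The upscale node pushes the image through comfy.utils.tiled_scale with 192x192 tiles (128 + 64) and an 8-pixel overlap, so large inputs do not have to fit through the upscale model in one pass. The loop below is only an illustrative, simplified version of that idea, with no overlap blending, and is not the actual comfy.utils implementation:

```python
import torch

def naive_tiled_scale(img, model_fn, tile=192, upscale_amount=4):
    """img: (B, C, H, W) tensor; model_fn upscales a tile by `upscale_amount`.
    Assumes the model keeps the channel count; real code also blends tile overlaps."""
    b, c, h, w = img.shape
    out = torch.zeros(b, c, h * upscale_amount, w * upscale_amount)
    for y in range(0, h, tile):
        for x in range(0, w, tile):
            patch = img[:, :, y:y + tile, x:x + tile]
            up = model_fn(patch)
            oy, ox = y * upscale_amount, x * upscale_amount
            out[:, :, oy:oy + up.shape[2], ox:ox + up.shape[3]] = up
    return out
```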
import os
import sys
import shutil
import threading
import asyncio
......@@ -8,9 +9,6 @@ if os.name == "nt":
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
import execution
import server
if __name__ == "__main__":
if '--help' in sys.argv:
print("Valid Command line Arguments:")
......@@ -18,6 +16,8 @@ if __name__ == "__main__":
print("\t--port 8188\t\t\tSet the listen port.")
print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
print("\t--use-pytorch-cross-attention\tUse the new pytorch 2.0 cross attention function.")
print("\t--disable-xformers\t\tdisables xformers")
print()
print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
......@@ -31,6 +31,9 @@ if __name__ == "__main__":
print("disabling upcasting of attention")
os.environ['ATTN_PRECISION'] = "fp16"
import execution
import server
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
while True:
......@@ -38,8 +41,8 @@ def prompt_worker(q, server):
e.execute(item[-2], item[-1])
q.task_done(item_id, e.outputs)
async def run(server, address='', port=8188, verbose=True):
await asyncio.gather(server.start(address, port, verbose), server.publish_loop())
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
def hijack_progress(server):
from tqdm.auto import tqdm
......@@ -51,7 +54,14 @@ def hijack_progress(server):
return v
setattr(tqdm, "update", wrapped_func)
def cleanup_temp():
temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
cleanup_temp()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
......@@ -76,11 +86,22 @@ if __name__ == "__main__":
except:
pass
if '--quick-test-for-ci' in sys.argv:
exit(0)
call_on_start = None
if "--windows-standalone-build" in sys.argv:
def startup_server(address, port):
import webbrowser
webbrowser.open("http://{}:{}".format(address, port))
call_on_start = startup_server
if os.name == "nt":
try:
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
except KeyboardInterrupt:
pass
else:
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
cleanup_temp()
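main.py now threads a call_on_start callback through run() into server.start(); with --windows-standalone-build it opens a browser once the server socket is actually up, and cleanup_temp() clears the temp preview folder both at startup and after the event loop exits. A small, self-contained sketch of the callback pattern (the names and the sleep stand-in are illustrative, not ComfyUI's server code):

```python
import asyncio

async def start_server(address, port, call_on_start=None):
    # ... bind the TCP site here ...
    print(f"To see the GUI go to: http://{address}:{port}")
    if call_on_start is not None:
        # Only fires after the socket is up, so the browser never hits a dead port.
        call_on_start(address, port)
    await asyncio.sleep(3600)  # stand-in for the real serve loop

def open_browser(address, port):
    import webbrowser
    webbrowser.open(f"http://{address}:{port}")

# asyncio.run(start_server("127.0.0.1", 8188, call_on_start=open_browser))
```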
......@@ -189,6 +189,7 @@ class VAEEncodeForInpaint:
y = (pixels.shape[2] // 64) * 64
mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0]
pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y:
pixels = pixels[:,:x,:y,:]
mask = mask[:x,:y]
......@@ -691,8 +692,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
if 'control' in p[1]:
control_nets += [p[1]['control']]
if 'control' in n[1]:
control_nets += [n[1]['control']]
negative_copy += [[t] + n[1:]]
control_net_models = []
......@@ -775,6 +776,7 @@ class KSamplerAdvanced:
class SaveImage:
def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
self.url_suffix = ""
@classmethod
def INPUT_TYPES(s):
......@@ -808,6 +810,9 @@ class SaveImage:
os.mkdir(self.output_dir)
counter = 1
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
paths = list()
for image in images:
i = 255. * image.cpu().numpy()
......@@ -820,10 +825,22 @@ class SaveImage:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
file = f"{filename_prefix}_{counter:05}_.png"
img.save(os.path.join(self.output_dir, file), pnginfo=metadata, optimize=True)
paths.append(file)
paths.append(file + self.url_suffix)
counter += 1
return { "ui": { "images": paths } }
class PreviewImage(SaveImage):
def __init__(self):
self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
self.url_suffix = "?type=temp"
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
class LoadImage:
input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@classmethod
......@@ -944,6 +961,7 @@ NODE_CLASS_MAPPINGS = {
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
"SaveImage": SaveImage,
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
"LoadImageMask": LoadImageMask,
"ImageScale": ImageScale,
......
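PreviewImage reuses SaveImage's saving logic unchanged: it only repoints output_dir at a temp folder and appends "?type=temp" to each returned filename so the /view endpoint (which now also accepts type=temp) serves from that folder, and it becomes available to the frontend once it is listed in NODE_CLASS_MAPPINGS. A stripped-down sketch of that subclassing pattern, with the names suffixed Sketch to mark them as illustrative:

```python
import os

class SaveImageSketch:
    def __init__(self):
        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
        self.url_suffix = ""  # appended to every filename returned to the UI

class PreviewImageSketch(SaveImageSketch):
    def __init__(self):
        # Same save code, different target folder plus a query-string hint so the
        # /view endpoint knows to look in "temp" instead of "output".
        self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
        self.url_suffix = "?type=temp"

NODE_CLASS_MAPPINGS_SKETCH = {
    "SaveImage": SaveImageSketch,
    "PreviewImage": PreviewImageSketch,
}
```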
......@@ -121,7 +121,7 @@ class PromptServer():
async def view_image(request):
if "file" in request.match_info:
type = request.rel_url.query.get("type", "output")
if type != "output" and type != "input":
if type not in ["output", "input", "temp"]:
return web.Response(status=400)
output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), type)
......@@ -268,7 +268,7 @@ class PromptServer():
msg = await self.messages.get()
await self.send(*msg)
async def start(self, address, port, verbose=True):
async def start(self, address, port, verbose=True, call_on_start=None):
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, address, port)
......@@ -279,3 +279,6 @@ class PromptServer():
if verbose:
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))
if call_on_start is not None:
call_on_start(address, port)
......@@ -142,7 +142,14 @@ class ComfyApp {
if (numImages === 1 && !imageIndex) {
this.imageIndex = imageIndex = 0;
}
let shiftY = this.type === "SaveImage" ? 55 : this.imageOffset || 0;
let shiftY;
if (this.imageOffset != null) {
shiftY = this.imageOffset;
} else {
shiftY = this.computeSize()[1];
}
let dw = this.size[0];
let dh = this.size[1];
dh -= shiftY;
......@@ -284,9 +291,47 @@ class ComfyApp {
document.addEventListener("drop", async (event) => {
event.preventDefault();
event.stopPropagation();
const file = event.dataTransfer.files[0];
await this.handleFile(file);
const n = this.dragOverNode;
this.dragOverNode = null;
// Node handles file drop; we don't use the built-in onDropFile handler as it's buggy
// If you drag multiple files it will call it multiple times with the same file
if (n && n.onDragDrop && (await n.onDragDrop(event))) {
return;
}
await this.handleFile(event.dataTransfer.files[0]);
});
// Always clear over node on drag leave
this.canvasEl.addEventListener("dragleave", async () => {
if (this.dragOverNode) {
this.dragOverNode = null;
this.graph.setDirtyCanvas(false, true);
}
});
// Add handler for dropping onto a specific node
this.canvasEl.addEventListener(
"dragover",
(e) => {
this.canvas.adjustMouseEvent(e);
const node = this.graph.getNodeOnPos(e.canvasX, e.canvasY);
if (node) {
if (node.onDragOver && node.onDragOver(e)) {
this.dragOverNode = node;
// dragover event is fired very frequently, run this on an animation frame
requestAnimationFrame(() => {
this.graph.setDirtyCanvas(false, true);
});
return;
}
}
this.dragOverNode = null;
},
false
);
}
/**
......@@ -314,15 +359,22 @@ class ComfyApp {
}
/**
* Draws currently executing node highlight and progress bar
* Draws node highlights (executing, drag drop) and progress bar
*/
#addDrawNodeProgressHandler() {
#addDrawNodeHandler() {
const orig = LGraphCanvas.prototype.drawNodeShape;
const self = this;
LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) {
const res = orig.apply(this, arguments);
if (node.id + "" === self.runningNodeId) {
let color = null;
if (node.id === +self.runningNodeId) {
color = "#0f0";
} else if (self.dragOverNode && node.id === self.dragOverNode.id) {
color = "dodgerblue";
}
if (color) {
const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE;
ctx.lineWidth = 1;
ctx.globalAlpha = 0.8;
......@@ -348,7 +400,7 @@ class ComfyApp {
);
else if (shape == LiteGraph.CIRCLE_SHAPE)
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2);
ctx.strokeStyle = "#0f0";
ctx.strokeStyle = color;
ctx.stroke();
ctx.strokeStyle = fgcolor;
ctx.globalAlpha = 1;
......@@ -398,6 +450,15 @@ class ComfyApp {
api.init();
}
#addKeyboardHandler() {
window.addEventListener("keydown", (e) => {
// Queue prompt using ctrl or command + enter
if ((e.ctrlKey || e.metaKey) && (e.key === "Enter" || e.keyCode === 13 || e.keyCode === 10)) {
this.queuePrompt(e.shiftKey ? -1 : 0);
}
});
}
/**
* Loads all extensions from the API into the window
*/
......@@ -419,7 +480,7 @@ class ComfyApp {
await this.#loadExtensions();
// Create and mount the LiteGraph in the DOM
const canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" });
const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" }));
document.body.prepend(canvasEl);
this.graph = new LGraph();
......@@ -460,10 +521,11 @@ class ComfyApp {
// Save current workflow automatically
setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())), 1000);
this.#addDrawNodeProgressHandler();
this.#addDrawNodeHandler();
this.#addApiUpdateHandlers();
this.#addDropHandler();
this.#addPasteHandler();
this.#addKeyboardHandler();
await this.#invokeExtensionsAsync("setup");
}
......@@ -497,7 +559,11 @@ class ComfyApp {
if (Array.isArray(type)) {
// Enums e.g. latent rotation
this.addWidget("combo", inputName, type[0], () => {}, { values: type });
let defaultValue = type[0];
if (inputData[1] && inputData[1].default) {
defaultValue = inputData[1].default;
}
this.addWidget("combo", inputName, defaultValue, () => {}, { values: type });
} else if (`${type}:${inputName}` in widgets) {
// Support custom widgets by Type:Name
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});
......@@ -641,31 +707,33 @@ class ComfyApp {
return { workflow, output };
}
async queuePrompt(number) {
const p = await this.graphToPrompt();
async queuePrompt(number, batchCount = 1) {
for (let i = 0; i < batchCount; i++) {
const p = await this.graphToPrompt();
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response || error.toString());
return;
}
try {
await api.queuePrompt(number, p);
} catch (error) {
this.ui.dialog.show(error.response || error.toString());
return;
}
for (const n of p.workflow.nodes) {
const node = graph.getNodeById(n.id);
if (node.widgets) {
for (const widget of node.widgets) {
// Allow widgets to run callbacks after a prompt has been queued
// e.g. random seed after every gen
if (widget.afterQueued) {
widget.afterQueued();
for (const n of p.workflow.nodes) {
const node = graph.getNodeById(n.id);
if (node.widgets) {
for (const widget of node.widgets) {
// Allow widgets to run callbacks after a prompt has been queued
// e.g. random seed after every gen
if (widget.afterQueued) {
widget.afterQueued();
}
}
}
}
}
this.canvas.draw(true, true);
await this.ui.queue.update();
this.canvas.draw(true, true);
await this.ui.queue.update();
}
}
/**
......
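queuePrompt() now takes a batchCount and converts and submits the graph once per iteration, so widget afterQueued callbacks (e.g. re-randomizing the seed) run between submissions. The same batching can be driven from outside the UI; a minimal sketch, assuming the backend's /prompt endpoint accepts a JSON body with a "prompt" field holding the API-format graph and that the server listens on 127.0.0.1:8188:

```python
import json
import urllib.request

def queue_prompt(prompt_graph, batch_count=1, host="127.0.0.1", port=8188):
    for _ in range(batch_count):
        # One request per batch item, mirroring the frontend's per-iteration loop.
        data = json.dumps({"prompt": prompt_graph}).encode("utf-8")
        req = urllib.request.Request(
            f"http://{host}:{port}/prompt",
            data=data,
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req) as resp:
            print(resp.status)
```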
......@@ -231,6 +231,7 @@ export class ComfyUI {
this.dialog = new ComfyDialog();
this.settings = new ComfySettingsDialog();
this.batchCount = 1;
this.queue = new ComfyList("Queue");
this.history = new ComfyList("History");
......@@ -254,9 +255,35 @@ export class ComfyUI {
$el("span", { $: (q) => (this.queueSize = q) }),
$el("button.comfy-settings-btn", { textContent: "⚙️", onclick: () => this.settings.show() }),
]),
$el("button.comfy-queue-btn", { textContent: "Queue Prompt", onclick: () => app.queuePrompt(0) }),
$el("button.comfy-queue-btn", { textContent: "Queue Prompt", onclick: () => app.queuePrompt(0, this.batchCount) }),
$el("div", {}, [
$el("label", { innerHTML: "Extra options"}, [
$el("input", { type: "checkbox",
onchange: (i) => {
document.getElementById('extraOptions').style.display = i.srcElement.checked ? "block" : "none";
this.batchCount = i.srcElement.checked ? document.getElementById('batchCountInputRange').value : 1;
}
})
])
]),
$el("div", { id: "extraOptions", style: { width: "100%", display: "none" }}, [
$el("label", { innerHTML: "Batch count" }, [
$el("input", { id: "batchCountInputNumber", type: "number", value: this.batchCount, min: "1", style: { width: "35%", "margin-left": "0.4em" },
oninput: (i) => {
this.batchCount = i.target.value;
document.getElementById('batchCountInputRange').value = this.batchCount;
}
}),
$el("input", { id: "batchCountInputRange", type: "range", min: "1", max: "100", value: this.batchCount,
oninput: (i) => {
this.batchCount = i.srcElement.value;
document.getElementById('batchCountInputNumber').value = i.srcElement.value;
}
}),
]),
]),
$el("div.comfy-menu-btns", [
$el("button", { textContent: "Queue Front", onclick: () => app.queuePrompt(-1) }),
$el("button", { textContent: "Queue Front", onclick: () => app.queuePrompt(-1, this.batchCount) }),
$el("button", {
$: (b) => (this.queue.button = b),
textContent: "View Queue",
......
......@@ -132,7 +132,7 @@ export const ComfyWidgets = {
function showImage(name) {
// Position the image somewhere sensible
if(!node.imageOffset) {
if (!node.imageOffset) {
node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75;
}
......@@ -162,6 +162,36 @@ export const ComfyWidgets = {
}
});
async function uploadFile(file, updateNode) {
try {
// Wrap file in formdata so it includes filename
const body = new FormData();
body.append("image", file);
const resp = await fetch("/upload/image", {
method: "POST",
body,
});
if (resp.status === 200) {
const data = await resp.json();
// Add the file as an option and update the widget value
if (!imageWidget.options.values.includes(data.name)) {
imageWidget.options.values.push(data.name);
}
if (updateNode) {
showImage(data.name);
imageWidget.value = data.name;
}
} else {
alert(resp.status + " - " + resp.statusText);
}
} catch (error) {
alert(error);
}
}
const fileInput = document.createElement("input");
Object.assign(fileInput, {
type: "file",
......@@ -169,30 +199,7 @@ export const ComfyWidgets = {
style: "display: none",
onchange: async () => {
if (fileInput.files.length) {
try {
// Wrap file in formdata so it includes filename
const body = new FormData();
body.append("image", fileInput.files[0]);
const resp = await fetch("/upload/image", {
method: "POST",
body,
});
if (resp.status === 200) {
const data = await resp.json();
showImage(data.name);
// Add the file as an option and update the widget value
if (!imageWidget.options.values.includes(data.name)) {
imageWidget.options.values.push(data.name);
}
imageWidget.value = data.name;
} else {
alert(resp.status + " - " + resp.statusText);
}
} catch (error) {
alert(error);
}
await uploadFile(fileInput.files[0], true);
}
},
});
......@@ -204,6 +211,30 @@ export const ComfyWidgets = {
});
uploadWidget.serialize = false;
// Add handler to check if an image is being dragged over our node
node.onDragOver = function (e) {
if (e.dataTransfer && e.dataTransfer.items) {
const image = [...e.dataTransfer.items].find((f) => f.kind === "file" && f.type.startsWith("image/"));
return !!image;
}
return false;
};
// On drop upload files
node.onDragDrop = function (e) {
console.log("onDragDrop called");
let handled = false;
for (const file of e.dataTransfer.files) {
if (file.type.startsWith("image/")) {
uploadFile(file, !handled); // Don't await these; any order is fine, only update on the first one
handled = true;
}
}
return handled;
};
return { widget: uploadWidget };
},
};
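The new uploadFile() helper wraps each dropped or picked file in FormData under the field name "image", POSTs it to /upload/image, and uses the "name" field of the JSON response to extend the image widget's options. The same endpoint can be exercised from Python; a minimal sketch using the third-party requests library, with the host/port and example path as assumptions:

```python
import requests

def upload_image(path, host="127.0.0.1", port=8188):
    with open(path, "rb") as f:
        # Field name "image" matches the FormData key the frontend uses.
        resp = requests.post(f"http://{host}:{port}/upload/image", files={"image": f})
    resp.raise_for_status()
    return resp.json()["name"]  # filename the LoadImage widget can now select

# print(upload_image("example.png"))  # "example.png" is an illustrative path
```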