import os
import sys
import copy
import json
import logging
import threading
import heapq
import traceback
import gc
import inspect
from typing import List, Literal, NamedTuple, Optional

import torch
import nodes

import comfy.model_management

def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
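    """Resolve a node's raw input spec into concrete values.

    `inputs` maps each input name either to a literal value or to a link of
    the form [source_node_id, output_slot_index]; linked values are looked up
    in `outputs`. Hidden inputs declared by the node class (PROMPT,
    EXTRA_PNGINFO, UNIQUE_ID) are filled from the execution context. Values
    are stored as lists so nodes can be mapped over batched inputs.
    """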
    valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                input_data_all[x] = (None,)
                continue
            obj = outputs[input_unique_id][output_index]
            input_data_all[x] = obj
        else:
            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
                input_data_all[x] = [input_data]

    if "hidden" in valid_inputs:
        h = valid_inputs["hidden"]
        for x in h:
            if h[x] == "PROMPT":
                input_data_all[x] = [prompt]
            if h[x] == "EXTRA_PNGINFO":
                if "extra_pnginfo" in extra_data:
                    input_data_all[x] = [extra_data['extra_pnginfo']]
            if h[x] == "UNIQUE_ID":
                input_data_all[x] = [unique_id]
    return input_data_all

def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
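    """Call `func` on `obj` once per slice of the batched inputs.

    If the node class sets INPUT_IS_LIST, the whole lists are passed in a
    single call. Otherwise the node runs once per index up to the longest
    input list, repeating the last element of any shorter list.
    """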
    # check if node wants the lists
    input_is_list = False
    if hasattr(obj, "INPUT_IS_LIST"):
        input_is_list = obj.INPUT_IS_LIST

    if len(input_data_all) == 0:
        max_len_input = 0
    else:
        max_len_input = max([len(x) for x in input_data_all.values()])
     
    # get a slice of inputs, repeat last input when list isn't long enough
    def slice_dict(d, i):
        d_new = dict()
        for k,v in d.items():
            d_new[k] = v[i if len(v) > i else -1]
        return d_new
    
    results = []
    if input_is_list:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)(**input_data_all))
    elif max_len_input == 0:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)())
    else:
        for i in range(max_len_input):
            if allow_interrupt:
                nodes.before_node_execution()
            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
    return results

def get_output_data(obj, input_data_all):
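    """Run a node's FUNCTION and split the results into output and UI data.

    Dict returns may carry 'ui' (collected for the frontend) and 'result'
    entries. Output slots flagged in OUTPUT_IS_LIST are concatenated across
    the batch; all other slots are gathered into per-slot lists.
    """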
    
    results = []
    uis = []
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

    for r in return_values:
        if isinstance(r, dict):
            if 'ui' in r:
                uis.append(r['ui'])
            if 'result' in r:
                results.append(r['result'])
        else:
            results.append(r)
    
    output = []
    if len(results) > 0:
        # check which outputs need concatenating
        output_is_list = [False] * len(results[0])
        if hasattr(obj, "OUTPUT_IS_LIST"):
            output_is_list = obj.OUTPUT_IS_LIST

        # merge node execution results
        for i, is_list in zip(range(len(results[0])), output_is_list):
            if is_list:
                output.append([x for o in results for x in o[i]])
            else:
                output.append([o[i] for o in results])

    ui = dict()    
    if len(uis) > 0:
        ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
    return output, ui

def format_value(x):
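    """Return a JSON-friendly representation of `x` for error reporting."""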
    if x is None:
        return None
    elif isinstance(x, (int, float, bool, str)):
        return x
    else:
        return str(x)

def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
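    """Execute `current_item`, first recursing into any uncached upstream nodes.

    Cached results in `outputs` are reused as-is. Returns a tuple of
    (success, error_details, exception); on failure the error identifies the
    node that raised so the frontend can highlight it.
    """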
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if unique_id in outputs:
        return (True, None, None)

    for x in inputs:
        input_data = inputs[x]

        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
                if result[0] is not True:
                    # Another node failed further upstream
                    return result

    input_data_all = None
    try:
        input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
        if server.client_id is not None:
            server.last_node_id = unique_id
            server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)

        obj = object_storage.get((unique_id, class_type), None)
        if obj is None:
            obj = class_def()
            object_storage[(unique_id, class_type)] = obj

        output_data, output_ui = get_output_data(obj, input_data_all)
        outputs[unique_id] = output_data
        if len(output_ui) > 0:
            outputs_ui[unique_id] = output_ui
            if server.client_id is not None:
                server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
    except comfy.model_management.InterruptProcessingException as iex:
        logging.info("Processing interrupted")

        # skip formatting inputs/outputs
        error_details = {
            "node_id": unique_id,
        }

        return (False, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)
        input_data_formatted = {}
        if input_data_all is not None:
            for name, node_inputs in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in node_inputs]

        output_data_formatted = {}
        for node_id, node_outputs in outputs.items():
            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]

        logging.error("!!! Exception during processing !!!")
        logging.error(traceback.format_exc())

        error_details = {
            "node_id": unique_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
            "current_outputs": output_data_formatted
        }
        return (False, error_details, ex)

    executed.add(unique_id)

    return (True, None, None)

def recursive_will_execute(prompt, outputs, current_item):
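    """Return the ids of `current_item` and all of its not-yet-cached ancestors.

    The length of this list serves as a cheap work estimate when
    PromptExecutor.execute picks which output to run next.
    """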
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    will_execute = []
    if unique_id in outputs:
        return []

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute(prompt, outputs, input_unique_id)

    return will_execute + [unique_id]

def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
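    """Drop the cached output of `current_item` if the node must re-run.

    A node is considered stale when its IS_CHANGED value differs from the
    previous run, it was absent from the old prompt, its inputs changed, or
    any upstream node is stale. Returns True when the cached output was
    deleted.
    """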
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

    is_changed_old = ''
    is_changed = ''
    to_delete = False
    if hasattr(class_def, 'IS_CHANGED'):
        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
            is_changed_old = old_prompt[unique_id]['is_changed']
        if 'is_changed' not in prompt[unique_id]:
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
            if input_data_all is not None:
                try:
                    #is_changed = class_def.IS_CHANGED(**input_data_all)
                    is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                    prompt[unique_id]['is_changed'] = is_changed
                except Exception:
                    to_delete = True
        else:
            is_changed = prompt[unique_id]['is_changed']

    if unique_id not in outputs:
        return True

    if not to_delete:
        if is_changed != is_changed_old:
            to_delete = True
        elif unique_id not in old_prompt:
            to_delete = True
        elif inputs == old_prompt[unique_id]['inputs']:
            for x in inputs:
                input_data = inputs[x]

                if isinstance(input_data, list):
                    input_unique_id = input_data[0]
                    output_index = input_data[1]
                    if input_unique_id in outputs:
                        to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
                    else:
                        to_delete = True
                    if to_delete:
                        break
        else:
            to_delete = True

    if to_delete:
        d = outputs.pop(unique_id)
        del d
    return to_delete

class PromptExecutor:
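    """Executes validated prompts, caching node outputs between runs."""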
    def __init__(self, server):
        self.server = server
        self.reset()

    def reset(self):
        self.outputs = {}
        self.object_storage = {}
        self.outputs_ui = {}
        self.status_messages = []
        self.success = True
        self.old_prompt = {}

    def add_message(self, event, data, broadcast: bool):
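        """Record a status message and send it to the client; `broadcast` sends even when no client id is set."""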
        self.status_messages.append((event, data))
        if self.server.client_id is not None or broadcast:
            self.server.send_sync(event, data, self.server.client_id)

    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
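        """Notify the frontend of an interrupt or error and drop cached outputs that will no longer be produced."""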
        node_id = error["node_id"]
        class_type = prompt[node_id]["class_type"]

        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, comfy.model_management.InterruptProcessingException):
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
            }
            self.add_message("execution_interrupted", mes, broadcast=True)
        else:
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),

                "exception_message": error["exception_message"],
                "exception_type": error["exception_type"],
                "traceback": error["traceback"],
                "current_inputs": error["current_inputs"],
                "current_outputs": error["current_outputs"],
            }
            self.add_message("execution_error", mes, broadcast=False)
        
        # Next, remove the subsequent outputs since they will not be executed
        to_delete = []
        for o in self.outputs:
            if (o not in current_outputs) and (o not in executed):
                to_delete += [o]
                if o in self.old_prompt:
                    d = self.old_prompt.pop(o)
                    del d
        for o in to_delete:
            d = self.outputs.pop(o)
            del d

    def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
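        """Run every node needed by `execute_outputs`, reusing and pruning cached outputs from previous prompts."""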
        nodes.interrupt_processing(False)

        if "client_id" in extra_data:
            self.server.client_id = extra_data["client_id"]
        else:
            self.server.client_id = None

        self.status_messages = []
        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)

        with torch.inference_mode():
            #delete cached outputs if nodes don't exist for them
            to_delete = []
            for o in self.outputs:
                if o not in prompt:
                    to_delete += [o]
            for o in to_delete:
                d = self.outputs.pop(o)
                del d
            to_delete = []
            for o in self.object_storage:
                if o[0] not in prompt:
                    to_delete += [o]
                else:
                    p = prompt[o[0]]
                    if o[1] != p['class_type']:
                        to_delete += [o]
            for o in to_delete:
                d = self.object_storage.pop(o)
                del d

            for x in prompt:
                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

            current_outputs = set(self.outputs.keys())
            for x in list(self.outputs_ui.keys()):
                if x not in current_outputs:
                    d = self.outputs_ui.pop(x)
                    del d

            comfy.model_management.cleanup_models()
            self.add_message("execution_cached",
                          { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                          broadcast=False)
            executed = set()
            output_node_id = None
            to_execute = []

            for node_id in list(execute_outputs):
                to_execute += [(0, node_id)]

            while len(to_execute) > 0:
                #always execute the output that depends on the least amount of unexecuted nodes first
                to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
                output_node_id = to_execute.pop(0)[-1]

                # This call shouldn't raise anything if there's an error deep in
                # the actual SD code, instead it will report the node where the
                # error was raised
                self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
                if self.success is not True:
                    self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                    break

            for x in executed:
                self.old_prompt[x] = copy.deepcopy(prompt[x])
            self.server.last_node_id = None
            if comfy.model_management.DISABLE_SMART_MEMORY:
                comfy.model_management.unload_all_models()

def validate_inputs(prompt, item, validated):
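    """Validate one node's inputs against its INPUT_TYPES, recursing into linked nodes.

    Results are memoized in `validated` as (valid, errors, node_id) tuples so
    shared upstream nodes are only checked once.
    """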
    unique_id = item
    if unique_id in validated:
        return validated[unique_id]

    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    class_inputs = obj_class.INPUT_TYPES()
    required_inputs = class_inputs['required']

    errors = []
    valid = True

    validate_function_inputs = []
    if hasattr(obj_class, "VALIDATE_INPUTS"):
        validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args

    for x in required_inputs:
        if x not in inputs:
            error = {
                "type": "required_input_missing",
                "message": "Required input is missing",
                "details": f"{x}",
                "extra_info": {
                    "input_name": x
                }
            }
            errors.append(error)
            continue

        val = inputs[x]
        info = required_inputs[x]
        type_input = info[0]
        if isinstance(val, list):
            if len(val) != 2:
                error = {
                    "type": "bad_linked_input",
                    "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val
                    }
                }
                errors.append(error)
                continue

            o_id = val[0]
            o_class_type = prompt[o_id]['class_type']
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            if r[val[1]] != type_input:
                received_type = r[val[1]]
                details = f"{x}, {received_type} != {type_input}"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
                    "details": details,
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_type": received_type,
                        "linked_node": val
                    }
                }
                errors.append(error)
                continue
            try:
                r = validate_inputs(prompt, o_id, validated)
                if r[0] is False:
                    # `r` will be set in `validated[o_id]` already
                    valid = False
                    continue
            except Exception as ex:
                typ, _, tb = sys.exc_info()
                valid = False
                exception_type = full_type_name(typ)
                reasons = [{
                    "type": "exception_during_inner_validation",
                    "message": "Exception when validating inner node",
                    "details": str(ex),
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "exception_message": str(ex),
                        "exception_type": exception_type,
                        "traceback": traceback.format_tb(tb),
                        "linked_node": val
                    }
                }]
                validated[o_id] = (False, reasons, o_id)
                continue
        else:
            try:
                if type_input == "INT":
                    val = int(val)
                    inputs[x] = val
                if type_input == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if type_input == "STRING":
                    val = str(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {type_input} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                        "exception_message": str(ex)
                    }
                }
                errors.append(error)
                continue

            if len(info) > 1:
                if "min" in info[1] and val < info[1]["min"]:
                    error = {
                        "type": "value_smaller_than_min",
                        "message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue
                if "max" in info[1] and val > info[1]["max"]:
                    error = {
                        "type": "value_bigger_than_max",
                        "message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue

            if x not in validate_function_inputs:
                if isinstance(type_input, list):
                    if val not in type_input:
                        input_config = info
                        list_info = ""

                        # Don't send back gigantic lists like if they're lots of
                        # scanned model filepaths
                        if len(type_input) > 20:
                            list_info = f"(list of length {len(type_input)})"
                            input_config = None
                        else:
                            list_info = str(type_input)

                        error = {
                            "type": "value_not_in_list",
                            "message": "Value not in list",
                            "details": f"{x}: '{val}' not in {list_info}",
                            "extra_info": {
                                "input_name": x,
                                "input_config": input_config,
                                "received_value": val,
                            }
                        }
                        errors.append(error)
                        continue

    if len(validate_function_inputs) > 0:
        input_data_all = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {}
        for x in input_data_all:
            if x in validate_function_inputs:
                input_filtered[x] = input_data_all[x]

        #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
        ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
        for x in input_filtered:
            for r in ret:
                if r is not True:
                    details = f"{x}"
                    if r is not False:
                        details += f" - {str(r)}"

                    error = {
                        "type": "custom_validation_failed",
                        "message": "Custom validation failed for node",
                        "details": details,
                        "extra_info": {
                            # `info`/`val` from the required-inputs loop above
                            # may be stale (or unbound) here, so only the input
                            # name is reported
                            "input_name": x,
                        }
                    }
                    errors.append(error)

    if len(errors) > 0 or valid is not True:
        ret = (False, errors, unique_id)
    else:
        ret = (True, [], unique_id)

    validated[unique_id] = ret
    return ret

def full_type_name(klass):
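    """Return the fully qualified name of a class, without the builtins prefix."""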
    module = klass.__module__
    if module == 'builtins':
        return klass.__qualname__
    return module + '.' + klass.__qualname__

def validate_prompt(prompt):
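    """Validate every output node in the prompt.

    Returns (valid, error, good_output_ids, node_errors); the prompt is
    accepted as long as at least one output node validates.
    """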
    outputs = set()
    for x in prompt:
        class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
            outputs.add(x)

    if len(outputs) == 0:
        error = {
            "type": "prompt_no_outputs",
            "message": "Prompt has no outputs",
            "details": "",
            "extra_info": {}
        }
        return (False, error, [], [])

    good_outputs = set()
    errors = []
    node_errors = {}
    validated = {}
    for o in outputs:
        valid = False
        reasons = []
        try:
            m = validate_inputs(prompt, o, validated)
            valid = m[0]
            reasons = m[1]
        except Exception as ex:
            typ, _, tb = sys.exc_info()
            valid = False
            exception_type = full_type_name(typ)
            reasons = [{
                "type": "exception_during_validation",
                "message": "Exception when validating node",
                "details": str(ex),
                "extra_info": {
                    "exception_type": exception_type,
                    "traceback": traceback.format_tb(tb)
                }
            }]
            validated[o] = (False, reasons, o)

        if valid is True:
            good_outputs.add(o)
        else:
            logging.error(f"Failed to validate prompt for output {o}:")
            if len(reasons) > 0:
                logging.error("* (prompt):")
                for reason in reasons:
                    logging.error(f"  - {reason['message']}: {reason['details']}")
            errors += [(o, reasons)]
            for node_id, result in validated.items():
                valid = result[0]
                reasons = result[1]
                # If a node upstream has errors, the nodes downstream will also
                # be reported as invalid, but there will be no errors attached.
                # So don't return those nodes as having errors in the response.
                if valid is not True and len(reasons) > 0:
                    if node_id not in node_errors:
                        class_type = prompt[node_id]['class_type']
                        node_errors[node_id] = {
                            "errors": reasons,
                            "dependent_outputs": [],
                            "class_type": class_type
                        }
                        logging.error(f"* {class_type} {node_id}:")
                        for reason in reasons:
                            logging.error(f"  - {reason['message']}: {reason['details']}")
                    node_errors[node_id]["dependent_outputs"].append(o)
            logging.error("Output will be ignored")

    if len(good_outputs) == 0:
        errors_list = []
        for o, output_errors in errors:
            for error in output_errors:
                errors_list.append(f"{error['message']}: {error['details']}")
        errors_list = "\n".join(errors_list)

        error = {
            "type": "prompt_outputs_failed_validation",
            "message": "Prompt outputs failed validation",
            "details": errors_list,
            "extra_info": {}
        }

        return (False, error, list(good_outputs), node_errors)

    return (True, None, list(good_outputs), node_errors)

MAXIMUM_HISTORY_SIZE = 10000

class PromptQueue:
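    """Thread-safe priority queue of prompts with a bounded execution history."""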
    def __init__(self, server):
        self.server = server
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        self.task_counter = 0
        self.queue = []
        self.currently_running = {}
        self.history = {}
        self.flags = {}
        server.prompt_queue = self

    def put(self, item):
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self, timeout=None):
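        """Pop the highest-priority queued item, blocking until one is available.

        Returns (item, task_id), or None when `timeout` elapses with the
        queue still empty.
        """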
        with self.not_empty:
            while len(self.queue) == 0:
                self.not_empty.wait(timeout=timeout)
                if timeout is not None and len(self.queue) == 0:
                    return None
            item = heapq.heappop(self.queue)
            i = self.task_counter
            self.currently_running[i] = copy.deepcopy(item)
            self.task_counter += 1
            self.server.queue_updated()
            return (item, i)

    class ExecutionStatus(NamedTuple):
        status_str: Literal['success', 'error']
        completed: bool
        messages: List[str]

    def task_done(self, item_id, outputs,
                  status: Optional['PromptQueue.ExecutionStatus']):
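        """Move a finished task out of currently_running and into the history."""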
        with self.mutex:
            prompt = self.currently_running.pop(item_id)
            if len(self.history) > MAXIMUM_HISTORY_SIZE:
                self.history.pop(next(iter(self.history)))

            status_dict: Optional[dict] = None
            if status is not None:
                status_dict = copy.deepcopy(status._asdict())

            self.history[prompt[1]] = {
                "prompt": prompt,
                "outputs": copy.deepcopy(outputs),
                'status': status_dict,
            }
            self.server.queue_updated()

    def get_current_queue(self):
        with self.mutex:
            out = []
            for x in self.currently_running.values():
                out += [x]
            return (out, copy.deepcopy(self.queue))

    def get_tasks_remaining(self):
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        with self.mutex:
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
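        """Remove the first queued item matching the `function` predicate; returns True if one was removed."""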
        with self.mutex:
            for x in range(len(self.queue)):
                if function(self.queue[x]):
                    if len(self.queue) == 1:
                        self.wipe_queue()
                    else:
                        self.queue.pop(x)
                        heapq.heapify(self.queue)
                    self.server.queue_updated()
                    return True
        return False

    def get_history(self, prompt_id=None, max_items=None, offset=-1):
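        """Return the entry for `prompt_id`, or a slice of the history.

        With `prompt_id` unset, up to `max_items` entries starting at
        `offset` are returned; a negative offset selects the newest entries.
        """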
        with self.mutex:
            if prompt_id is None:
                out = {}
                i = 0
                if offset < 0 and max_items is not None:
                    offset = len(self.history) - max_items
                for k in self.history:
                    if i >= offset:
                        out[k] = self.history[k]
                        if max_items is not None and len(out) >= max_items:
                            break
                    i += 1
                return out
            elif prompt_id in self.history:
                return {prompt_id: copy.deepcopy(self.history[prompt_id])}
            else:
                return {}

    def wipe_history(self):
        with self.mutex:
            self.history = {}

    def delete_history_item(self, id_to_delete):
        with self.mutex:
            self.history.pop(id_to_delete, None)

    def set_flag(self, name, data):
        with self.mutex:
            self.flags[name] = data
            self.not_empty.notify()

    def get_flags(self, reset=True):
        with self.mutex:
            if reset:
                ret = self.flags
                self.flags = {}
                return ret
            else:
                return self.flags.copy()