execution.py 30.1 KB
Newer Older
1
2
import sys
import copy
3
import logging
4
5
6
import threading
import heapq
import traceback
7
import inspect
8
from typing import List, Literal, NamedTuple, Optional
9
10
11
12

import torch
import nodes

13
import comfy.model_management
14

15
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
    """Build the per-input value lists used to call a node's function.

    Args:
        inputs: the node's ``inputs`` mapping from the prompt. A list value is a
            link ``[source_node_id, output_index]``; anything else is a literal.
        class_def: the node class. ``class_def.INPUT_TYPES()`` decides which
            literal inputs are passed through and which hidden inputs are added.
        unique_id: this node's id, supplied for a ``UNIQUE_ID`` hidden input.
        outputs: already-computed outputs keyed by node id (default: none).
        prompt: the full prompt, supplied for a ``PROMPT`` hidden input
            (default: a fresh empty dict).
        extra_data: extra request data; its ``extra_pnginfo`` entry is supplied
            for an ``EXTRA_PNGINFO`` hidden input (default: none).

    Returns:
        dict mapping input name to a sequence of values. Linked inputs map to
        the source node's output list, or ``(None,)`` when the source node has
        no entry in ``outputs``. Literal inputs declared as required or
        optional are wrapped in a one-element list; undeclared literals are
        dropped.
    """
    # Fresh containers per call: the PROMPT hidden input hands ``prompt`` to
    # node code, so a shared default dict could be mutated across calls.
    if outputs is None:
        outputs = {}
    if prompt is None:
        prompt = {}
    if extra_data is None:
        extra_data = {}

    valid_inputs = class_def.INPUT_TYPES()
    required = valid_inputs.get("required", {})
    optional = valid_inputs.get("optional", {})

    input_data_all = {}
    for x, input_data in inputs.items():
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                # Linked node has no output available (yet).
                input_data_all[x] = (None,)
                continue
            input_data_all[x] = outputs[input_unique_id][output_index]
        elif x in required or x in optional:
            input_data_all[x] = [input_data]

    for x, hidden_kind in valid_inputs.get("hidden", {}).items():
        if hidden_kind == "PROMPT":
            input_data_all[x] = [prompt]
        elif hidden_kind == "EXTRA_PNGINFO":
            input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
        elif hidden_kind == "UNIQUE_ID":
            input_data_all[x] = [unique_id]

    return input_data_all
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    """Invoke ``obj.<func>`` over the rows of ``input_data_all`` and collect the results.

    - If ``obj`` sets ``INPUT_IS_LIST``, the method is called once with the whole
      lists as keyword arguments.
    - If there are no inputs, or every input list is empty, it is called once
      with no arguments.
    - Otherwise it is called once per index up to the longest input list; a
      shorter list keeps supplying its last element.

    When ``allow_interrupt`` is true, ``nodes.before_node_execution()`` runs
    before every call. Returns the list of call results in order.
    """
    wants_whole_lists = getattr(obj, "INPUT_IS_LIST", False)
    longest = max((len(values) for values in input_data_all.values()), default=0)

    def row(index):
        # Pick element ``index`` of each input, clamping to the last one.
        return {name: values[index if len(values) > index else -1]
                for name, values in input_data_all.items()}

    def invoke(kwargs):
        if allow_interrupt:
            nodes.before_node_execution()
        return getattr(obj, func)(**kwargs)

    if wants_whole_lists:
        return [invoke(input_data_all)]
    if longest == 0:
        return [invoke({})]
    return [invoke(row(index)) for index in range(longest)]
def get_output_data(obj, input_data_all):
    """Run the node's ``FUNCTION`` over its inputs and merge the per-call results.

    Each call may return a plain result tuple or a dict with optional ``'result'``
    and ``'ui'`` entries.

    Returns ``(output, ui)``:
      * ``output`` has one list per output slot of the first result. A slot
        whose ``OUTPUT_IS_LIST`` flag is true is flattened across calls;
        otherwise it collects one value per call.
      * ``ui`` concatenates, for each key of the first UI dict, that key's
        lists from every call; it is ``{}`` when no call returned UI data.
    """
    call_results = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

    results = []
    uis = []
    for value in call_results:
        if not isinstance(value, dict):
            results.append(value)
            continue
        if 'ui' in value:
            uis.append(value['ui'])
        if 'result' in value:
            results.append(value['result'])

    output = []
    if results:
        first = results[0]
        output_is_list = getattr(obj, "OUTPUT_IS_LIST", [False] * len(first))
        for slot, flatten in zip(range(len(first)), output_is_list):
            if flatten:
                output.append([item for res in results for item in res[slot]])
            else:
                output.append([res[slot] for res in results])

    ui = {}
    if uis:
        ui = {key: [entry for part in uis for entry in part[key]] for key in uis[0].keys()}
    return output, ui
def format_value(x):
    """Return ``x`` as-is when it is None or a plain scalar, else its ``str()``.

    Used to make node inputs/outputs safe to put into an error report.
    """
    if x is None or isinstance(x, (int, float, bool, str)):
        return x
    return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
    """Execute node ``current_item``, first executing any linked nodes not yet in ``outputs``.

    On success the node's merged outputs are stored in ``outputs[current_item]``.
    Any UI data goes into ``outputs_ui``. The node id is added to ``executed``,
    and the node instance is cached in ``object_storage`` under
    ``(node_id, class_type)``. When ``server.client_id`` is set, "executing" and
    "executed" events are sent to that client.

    Returns:
        ``(True, None, None)`` on success or when the output was already
        cached. Otherwise ``(False, error_details, exception)`` for the first
        failing node (possibly upstream). An interrupt yields only
        ``{"node_id": ...}`` as details.
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if unique_id in outputs:
        return (True, None, None)

    for input_data in inputs.values():
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            if input_unique_id not in outputs:
                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
                if result[0] is not True:
                    # Another node failed further upstream
                    return result

    input_data_all = None
    try:
        input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
        if server.client_id is not None:
            server.last_node_id = unique_id
            server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)

        # Reuse the node instance from an earlier run when id and class match.
        obj = object_storage.get((unique_id, class_type), None)
        if obj is None:
            obj = class_def()
            object_storage[(unique_id, class_type)] = obj

        output_data, output_ui = get_output_data(obj, input_data_all)
        outputs[unique_id] = output_data
        if len(output_ui) > 0:
            outputs_ui[unique_id] = output_ui
            if server.client_id is not None:
                server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
    except comfy.model_management.InterruptProcessingException as iex:
        logging.info("Processing interrupted")

        # skip formatting inputs/outputs
        error_details = {
            "node_id": unique_id,
        }

        return (False, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)

        # Stringify inputs/outputs so the error report can be serialized.
        input_data_formatted = {}
        if input_data_all is not None:
            input_data_formatted = {
                name: [format_value(x) for x in values]
                for name, values in input_data_all.items()
            }

        output_data_formatted = {
            node_id: [[format_value(x) for x in output_list] for output_list in node_outputs]
            for node_id, node_outputs in outputs.items()
        }

        logging.error(f"!!! Exception during processing!!! {ex}")
        logging.error(traceback.format_exc())

        error_details = {
            "node_id": unique_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
            "current_outputs": output_data_formatted
        }
        return (False, error_details, ex)

    executed.add(unique_id)

    return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item, memo=None):
    """Return the ids of nodes that must run to produce ``current_item``'s output.

    Nodes already present in ``outputs`` are skipped. The list is in dependency
    order (inputs before the node itself) and may repeat a node reached through
    several paths. ``PromptExecutor.execute`` uses only its length, as a cost
    estimate.

    Args:
        prompt: the full prompt (node id -> dict with an ``'inputs'`` mapping).
        outputs: already-computed outputs keyed by node id.
        current_item: id of the node to analyse.
        memo: optional per-pass cache of results keyed by node id. A fresh one
            is created when omitted, so results never leak between prompts.
    """
    if memo is None:
        memo = {}
    unique_id = current_item

    if unique_id in memo:
        return memo[unique_id]

    inputs = prompt[unique_id]['inputs']
    if unique_id in outputs:
        return []

    will_execute = []
    for input_data in inputs.values():
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)

    memo[unique_id] = will_execute + [unique_id]
    return memo[unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
    """Drop ``current_item``'s cached output if it (or anything it depends on) changed.

    A cached output is deleted when any of these hold:
      * the node's ``IS_CHANGED`` result differs from the one stored in
        ``old_prompt``, or evaluating ``IS_CHANGED`` raised;
      * the node did not exist in ``old_prompt``;
      * its ``inputs`` differ from the old ones;
      * a linked input's node has no cached output, or its output was deleted
        by the recursive check.

    Side effect: the computed ``IS_CHANGED`` value is stored in
    ``prompt[current_item]['is_changed']`` so it is evaluated at most once per prompt.

    Returns:
        True if the node has no cached output (anymore), False if the cached
        output was kept.
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

    is_changed_old = ''
    is_changed = ''
    to_delete = False
    if hasattr(class_def, 'IS_CHANGED'):
        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
            is_changed_old = old_prompt[unique_id]['is_changed']
        if 'is_changed' not in prompt[unique_id]:
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
            try:
                is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                prompt[unique_id]['is_changed'] = is_changed
            except Exception:
                # A failing IS_CHANGED means the cache cannot be trusted: re-run the node.
                # (Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate.)
                to_delete = True
        else:
            is_changed = prompt[unique_id]['is_changed']

    if unique_id not in outputs:
        return True

    if not to_delete:
        if is_changed != is_changed_old:
            to_delete = True
        elif unique_id not in old_prompt:
            to_delete = True
        elif inputs == old_prompt[unique_id]['inputs']:
            for input_data in inputs.values():
                if isinstance(input_data, list):
                    input_unique_id = input_data[0]
                    if input_unique_id in outputs:
                        to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
                    else:
                        to_delete = True
                    if to_delete:
                        break
        else:
            to_delete = True

    if to_delete:
        outputs.pop(unique_id)
    return to_delete
class PromptExecutor:
    """Executes prompts, keeping node outputs and node instances cached between runs.

    ``server`` must provide ``client_id``, ``last_node_id`` and
    ``send_sync(event, data, client_id)``; status events are sent through it.
    """

    def __init__(self, server):
        self.server = server
        self.reset()

    def reset(self):
        """Forget every cached output, node instance, UI result and status message."""
        # Node outputs keyed by node id, reused when the node is unchanged.
        self.outputs = {}
        # Node instances keyed by (node_id, class_type).
        self.object_storage = {}
        # UI payloads returned by nodes, keyed by node id.
        self.outputs_ui = {}
        # (event, data) pairs recorded during the current execution.
        self.status_messages = []
        self.success = True
        # Deep copies of executed nodes' prompt entries, compared on the next run.
        self.old_prompt = {}

    def add_message(self, event, data, broadcast: bool):
        """Record a status event and send it to the client (or everyone if ``broadcast``)."""
        self.status_messages.append((event, data))
        if self.server.client_id is not None or broadcast:
            self.server.send_sync(event, data, self.server.client_id)

    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
        """Report a failed/interrupted node and drop outputs that are now unreliable.

        Outputs that were neither cached at the start of this run
        (``current_outputs``) nor produced by a node in ``executed`` are removed,
        together with their ``old_prompt`` snapshots.
        """
        node_id = error["node_id"]
        class_type = prompt[node_id]["class_type"]

        mes = {
            "prompt_id": prompt_id,
            "node_id": node_id,
            "node_type": class_type,
            "executed": list(executed),
        }
        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, comfy.model_management.InterruptProcessingException):
            self.add_message("execution_interrupted", mes, broadcast=True)
        else:
            mes.update({
                "exception_message": error["exception_message"],
                "exception_type": error["exception_type"],
                "traceback": error["traceback"],
                "current_inputs": error["current_inputs"],
                "current_outputs": error["current_outputs"],
            })
            self.add_message("execution_error", mes, broadcast=False)

        # Next, remove the subsequent outputs since they will not be executed
        to_delete = [o for o in self.outputs if o not in current_outputs and o not in executed]
        for o in to_delete:
            self.old_prompt.pop(o, None)
            self.outputs.pop(o)

    def execute(self, prompt, prompt_id, extra_data=None, execute_outputs=None):
        """Execute the output nodes ``execute_outputs`` of ``prompt``.

        Stale cache entries are pruned first. Then output nodes run one at a
        time, cheapest first (fewest nodes still to execute). Execution stops at
        the first error, which is reported via ``handle_execution_error``.
        ``self.success`` holds the outcome of the last node executed.

        Args:
            prompt: the full prompt (node id -> node dict).
            prompt_id: id echoed back in every status message.
            extra_data: optional request data; ``client_id`` selects the client
                that receives status messages.
            execute_outputs: ids of the output nodes to execute.
        """
        # None defaults avoid sharing a mutable default between calls.
        if extra_data is None:
            extra_data = {}
        if execute_outputs is None:
            execute_outputs = []

        nodes.interrupt_processing(False)

        self.server.client_id = extra_data.get("client_id")

        self.status_messages = []
        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)

        with torch.inference_mode():
            # delete cached outputs if nodes don't exist for them
            for o in [o for o in self.outputs if o not in prompt]:
                self.outputs.pop(o)

            # drop node instances whose node was removed or changed class
            stale_objects = [key for key in self.object_storage
                             if key[0] not in prompt or key[1] != prompt[key[0]]['class_type']]
            for key in stale_objects:
                self.object_storage.pop(key)

            for x in prompt:
                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

            current_outputs = set(self.outputs.keys())
            for x in list(self.outputs_ui.keys()):
                if x not in current_outputs:
                    self.outputs_ui.pop(x)

            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
            self.add_message("execution_cached",
                          { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                          broadcast=False)
            executed = set()
            to_execute = [(0, node_id) for node_id in execute_outputs]

            while len(to_execute) > 0:
                #always execute the output that depends on the least amount of unexecuted nodes first
                memo = {}
                to_execute = sorted(
                    (len(recursive_will_execute(prompt, self.outputs, entry[-1], memo)), entry[-1])
                    for entry in to_execute
                )
                output_node_id = to_execute.pop(0)[-1]

                # This call shouldn't raise anything if there's an error deep in
                # the actual SD code, instead it will report the node where the
                # error was raised
                self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
                if self.success is not True:
                    self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                    break

            for x in executed:
                self.old_prompt[x] = copy.deepcopy(prompt[x])
            self.server.last_node_id = None
            if comfy.model_management.DISABLE_SMART_MEMORY:
                comfy.model_management.unload_all_models()
def validate_inputs(prompt, item, validated):
    """Validate node ``item``'s inputs, recursing into every node it is linked to.

    Args:
        prompt: the full prompt (node id -> dict with ``class_type`` and
            ``inputs``). Literal INT/FLOAT/STRING inputs are converted in place.
        item: id of the node to validate.
        validated: memo of results keyed by node id. It is filled in for this
            node and for every linked node visited.

    Returns:
        ``(True, [], item)`` when valid. Otherwise ``(False, errors, item)``,
        where ``errors`` is a list of error dicts; it is empty when the node is
        invalid only because a linked node is.

    Lookups of a linked node (``prompt[node_id]``, its RETURN_TYPES slot) are
    not guarded, so a dangling link raises and is reported by the caller.
    """
    unique_id = item
    if unique_id in validated:
        return validated[unique_id]

    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    class_inputs = obj_class.INPUT_TYPES()
    # A node may declare only optional inputs; treat a missing section as empty.
    required_inputs = class_inputs.get('required', {})

    errors = []
    valid = True

    # Inputs named in VALIDATE_INPUTS' signature are validated by the node itself,
    # so the built-in "value in list" check below is skipped for them.
    validate_function_inputs = []
    if hasattr(obj_class, "VALIDATE_INPUTS"):
        validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args

    for x in required_inputs:
        if x not in inputs:
            error = {
                "type": "required_input_missing",
                "message": "Required input is missing",
                "details": f"{x}",
                "extra_info": {
                    "input_name": x
                }
            }
            errors.append(error)
            continue

        val = inputs[x]
        info = required_inputs[x]
        type_input = info[0]
        if isinstance(val, list):
            # Linked input: [source_node_id, output_slot].
            if len(val) != 2:
                error = {
                    "type": "bad_linked_input",
                    "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val
                    }
                }
                errors.append(error)
                continue

            o_id = val[0]
            o_class_type = prompt[o_id]['class_type']
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            if r[val[1]] != type_input:
                received_type = r[val[1]]
                details = f"{x}, {received_type} != {type_input}"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
                    "details": details,
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_type": received_type,
                        "linked_node": val
                    }
                }
                errors.append(error)
                continue
            try:
                r = validate_inputs(prompt, o_id, validated)
                if r[0] is False:
                    # `r` will be set in `validated[o_id]` already
                    valid = False
                    continue
            except Exception as ex:
                typ, _, tb = sys.exc_info()
                valid = False
                exception_type = full_type_name(typ)
                reasons = [{
                    "type": "exception_during_inner_validation",
                    "message": "Exception when validating inner node",
                    "details": str(ex),
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "exception_message": str(ex),
                        "exception_type": exception_type,
                        "traceback": traceback.format_tb(tb),
                        "linked_node": val
                    }
                }]
                validated[o_id] = (False, reasons, o_id)
                continue
        else:
            # Literal input: coerce primitive types in place.
            try:
                if type_input == "INT":
                    val = int(val)
                    inputs[x] = val
                elif type_input == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                elif type_input == "STRING":
                    val = str(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {type_input} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                        "exception_message": str(ex)
                    }
                }
                errors.append(error)
                continue

            if len(info) > 1:
                if "min" in info[1] and val < info[1]["min"]:
                    error = {
                        "type": "value_smaller_than_min",
                        "message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue
                if "max" in info[1] and val > info[1]["max"]:
                    error = {
                        "type": "value_bigger_than_max",
                        "message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue

            if x not in validate_function_inputs:
                if isinstance(type_input, list):
                    if val not in type_input:
                        input_config = info
                        list_info = ""

                        # Don't send back gigantic lists like if they're lots of
                        # scanned model filepaths
                        if len(type_input) > 20:
                            list_info = f"(list of length {len(type_input)})"
                            input_config = None
                        else:
                            list_info = str(type_input)

                        error = {
                            "type": "value_not_in_list",
                            "message": "Value not in list",
                            "details": f"{x}: '{val}' not in {list_info}",
                            "extra_info": {
                                "input_name": x,
                                "input_config": input_config,
                                "received_value": val,
                            }
                        }
                        errors.append(error)
                        continue

    if len(validate_function_inputs) > 0:
        input_data_all = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {x: v for x, v in input_data_all.items() if x in validate_function_inputs}

        ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
        # Report each input's own config/value. The loop variables `info`/`val`
        # from above would be stale here, or unbound when there are no required inputs.
        input_configs = {**class_inputs.get('optional', {}), **required_inputs}
        for x in input_filtered:
            for r in ret:
                if r is not True:
                    details = f"{x}"
                    if r is not False:
                        details += f" - {str(r)}"

                    error = {
                        "type": "custom_validation_failed",
                        "message": "Custom validation failed for node",
                        "details": details,
                        "extra_info": {
                            "input_name": x,
                            "input_config": input_configs.get(x),
                            "received_value": inputs.get(x),
                        }
                    }
                    errors.append(error)

    if len(errors) > 0 or valid is not True:
        ret = (False, errors, unique_id)
    else:
        ret = (True, [], unique_id)

    validated[unique_id] = ret
    return ret
def full_type_name(klass):
    """Return the dotted ``module.QualName`` of ``klass``; builtins get just the qualname."""
    qualname = klass.__qualname__
    if klass.__module__ == 'builtins':
        return qualname
    return f"{klass.__module__}.{qualname}"
def validate_prompt(prompt):
    """Validate every output node of ``prompt`` together with the nodes it depends on.

    Returns a 4-tuple ``(ok, error, good_outputs, node_errors)``:
      * ``ok``: True when at least one output node validated.
      * ``error``: None on success, else an error dict
        (``type``/``message``/``details``/``extra_info``).
      * ``good_outputs``: ids of output nodes that validated.
      * ``node_errors``: node id -> ``{"errors", "dependent_outputs", "class_type"}``
        for nodes that carry their own errors. It is ``[]`` when validation
        stops before output nodes are checked.
    """
    outputs = set()
    for x in prompt:
        if 'class_type' not in prompt[x]:
            error = {
                "type": "invalid_prompt",
                "message": "Cannot execute due to a missing node",
                "details": f"Node ID '#{x}'",
                "extra_info": {}
            }
            return (False, error, [], [])

        class_type = prompt[x]['class_type']
        # Report unknown node types instead of raising KeyError out of validation.
        if class_type not in nodes.NODE_CLASS_MAPPINGS:
            error = {
                "type": "invalid_prompt",
                "message": f"Cannot execute because node {class_type} does not exist.",
                "details": f"Node ID '#{x}'",
                "extra_info": {}
            }
            return (False, error, [], [])

        class_ = nodes.NODE_CLASS_MAPPINGS[class_type]
        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
            outputs.add(x)

    if len(outputs) == 0:
        error = {
            "type": "prompt_no_outputs",
            "message": "Prompt has no outputs",
            "details": "",
            "extra_info": {}
        }
        return (False, error, [], [])

    good_outputs = set()
    errors = []
    node_errors = {}
    validated = {}
    for o in outputs:
        valid = False
        reasons = []
        try:
            m = validate_inputs(prompt, o, validated)
            valid = m[0]
            reasons = m[1]
        except Exception as ex:
            typ, _, tb = sys.exc_info()
            valid = False
            exception_type = full_type_name(typ)
            reasons = [{
                "type": "exception_during_validation",
                "message": "Exception when validating node",
                "details": str(ex),
                "extra_info": {
                    "exception_type": exception_type,
                    "traceback": traceback.format_tb(tb)
                }
            }]
            validated[o] = (False, reasons, o)

        if valid is True:
            good_outputs.add(o)
        else:
            logging.error(f"Failed to validate prompt for output {o}:")
            if len(reasons) > 0:
                logging.error("* (prompt):")
                for reason in reasons:
                    logging.error(f"  - {reason['message']}: {reason['details']}")
            errors += [(o, reasons)]
            for node_id, result in validated.items():
                node_valid = result[0]
                node_reasons = result[1]
                # If a node upstream has errors, the nodes downstream will also
                # be reported as invalid, but there will be no errors attached.
                # So don't return those nodes as having errors in the response.
                if node_valid is not True and len(node_reasons) > 0:
                    if node_id not in node_errors:
                        class_type = prompt[node_id]['class_type']
                        node_errors[node_id] = {
                            "errors": node_reasons,
                            "dependent_outputs": [],
                            "class_type": class_type
                        }
                        logging.error(f"* {class_type} {node_id}:")
                        for reason in node_reasons:
                            logging.error(f"  - {reason['message']}: {reason['details']}")
                    node_errors[node_id]["dependent_outputs"].append(o)
            logging.error("Output will be ignored")

    if len(good_outputs) == 0:
        errors_list = []
        for _output_id, output_reasons in errors:
            for error in output_reasons:
                errors_list.append(f"{error['message']}: {error['details']}")
        errors_text = "\n".join(errors_list)

        error = {
            "type": "prompt_outputs_failed_validation",
            "message": "Prompt outputs failed validation",
            "details": errors_text,
            "extra_info": {}
        }

        return (False, error, list(good_outputs), node_errors)

    return (True, None, list(good_outputs), node_errors)
# Size limit used when trimming PromptQueue.history in task_done.
MAXIMUM_HISTORY_SIZE = 10_000
class PromptQueue:
    """Thread-safe heap of pending prompts, plus running tasks, history and flags.

    All state is guarded by one re-entrant lock. ``not_empty`` is a condition
    on that lock; ``put`` and ``set_flag`` notify it. ``server.queue_updated()``
    is called whenever the queue or the running set changes.
    """

    def __init__(self, server):
        self.server = server
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        self.task_counter = 0
        self.queue = []
        self.currently_running = {}
        self.history = {}
        self.flags = {}
        server.prompt_queue = self

    def put(self, item):
        """Push ``item`` onto the heap (smallest item is served first)."""
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self, timeout=None):
        """Pop the smallest item and mark it running.

        Returns ``(item, task_id)``, or None if ``timeout`` is given and the
        queue is still empty after a wake-up.
        """
        with self.not_empty:
            while len(self.queue) == 0:
                self.not_empty.wait(timeout=timeout)
                if timeout is not None and len(self.queue) == 0:
                    return None
            item = heapq.heappop(self.queue)
            i = self.task_counter
            self.currently_running[i] = copy.deepcopy(item)
            self.task_counter += 1
            self.server.queue_updated()
            return (item, i)

    class ExecutionStatus(NamedTuple):
        status_str: Literal['success', 'error']
        completed: bool
        messages: List[str]

    def task_done(self, item_id, outputs,
                  status: Optional['PromptQueue.ExecutionStatus']):
        """Move running task ``item_id`` into the history with its outputs and status."""
        with self.mutex:
            prompt = self.currently_running.pop(item_id)
            # Evict the oldest entries (dicts preserve insertion order) so the
            # history holds at most MAXIMUM_HISTORY_SIZE entries after this insert.
            while self.history and len(self.history) >= MAXIMUM_HISTORY_SIZE:
                self.history.pop(next(iter(self.history)))

            status_dict: Optional[dict] = None
            if status is not None:
                status_dict = copy.deepcopy(status._asdict())

            # prompt[1] is used as the history key (presumably the prompt id — confirm against put callers).
            self.history[prompt[1]] = {
                "prompt": prompt,
                "outputs": copy.deepcopy(outputs),
                'status': status_dict,
            }
            self.server.queue_updated()

    def get_current_queue(self):
        """Return ``(running_items, copy_of_pending_heap)``."""
        with self.mutex:
            out = []
            for x in self.currently_running.values():
                out += [x]
            return (out, copy.deepcopy(self.queue))

    def get_tasks_remaining(self):
        """Return the number of pending plus running tasks."""
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        """Discard every pending item."""
        with self.mutex:
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
        """Remove the first pending item for which ``function(item)`` is true.

        Returns True if an item was removed. ``queue_updated`` fires exactly once
        per removal; previously it fired twice when removing the last item.
        """
        with self.mutex:
            for index, queued in enumerate(self.queue):
                if function(queued):
                    self.queue.pop(index)
                    heapq.heapify(self.queue)
                    self.server.queue_updated()
                    return True
        return False

    def get_history(self, prompt_id=None, max_items=None, offset=-1):
        """Return history entries.

        With ``prompt_id``, return a deep copy of that single entry (or ``{}``
        if it is absent). Otherwise return up to ``max_items`` entries starting
        at ``offset``; a negative offset with ``max_items`` means the newest
        ``max_items`` entries.
        """
        with self.mutex:
            if prompt_id is None:
                out = {}
                i = 0
                if offset < 0 and max_items is not None:
                    offset = len(self.history) - max_items
                for k in self.history:
                    if i >= offset:
                        out[k] = self.history[k]
                        if max_items is not None and len(out) >= max_items:
                            break
                    i += 1
                return out
            elif prompt_id in self.history:
                return {prompt_id: copy.deepcopy(self.history[prompt_id])}
            else:
                return {}

    def wipe_history(self):
        """Remove every history entry."""
        with self.mutex:
            self.history = {}

    def delete_history_item(self, id_to_delete):
        """Remove one history entry if present."""
        with self.mutex:
            self.history.pop(id_to_delete, None)

    def set_flag(self, name, data):
        """Set flag ``name`` and wake a thread waiting in ``get``."""
        with self.mutex:
            self.flags[name] = data
            self.not_empty.notify()

    def get_flags(self, reset=True):
        """Return the flags; with ``reset`` the stored flags are cleared."""
        with self.mutex:
            if reset:
                ret = self.flags
                self.flags = {}
                return ret
            else:
                return self.flags.copy()