import sys
import copy
import logging
import threading
import heapq
import time
import traceback
import inspect
from typing import List, Literal, NamedTuple, Optional

import torch
import nodes

import comfy.model_management

def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
    valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                input_data_all[x] = (None,)
                continue
            obj = outputs[input_unique_id][output_index]
            input_data_all[x] = obj
        else:
            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
                input_data_all[x] = [input_data]

    if "hidden" in valid_inputs:
        h = valid_inputs["hidden"]
        for x in h:
            if h[x] == "PROMPT":
                input_data_all[x] = [prompt]
            if h[x] == "EXTRA_PNGINFO":
                input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
            if h[x] == "UNIQUE_ID":
                input_data_all[x] = [unique_id]
    return input_data_all
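
# For reference (illustrative values, not from any real workflow): a linked
# input arrives as a two-element list [node_id, slot_index] and resolves to
# the upstream node's cached output, while a literal widget value is passed
# through wrapped in a single-element list, e.g.
#   inputs = {"samples": ["4", 0], "seed": 123}
#   # -> input_data_all == {"samples": outputs["4"][0], "seed": [123]}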

def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    # check if node wants the lists
    input_is_list = False
    if hasattr(obj, "INPUT_IS_LIST"):
        input_is_list = obj.INPUT_IS_LIST

    if len(input_data_all) == 0:
        max_len_input = 0
    else:
        max_len_input = max([len(x) for x in input_data_all.values()])
     
    # get a slice of inputs, repeat last input when list isn't long enough
    def slice_dict(d, i):
        d_new = dict()
        for k,v in d.items():
            d_new[k] = v[i if len(v) > i else -1]
        return d_new
    
    results = []
    if input_is_list:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)(**input_data_all))
    elif max_len_input == 0:
        if allow_interrupt:
            nodes.before_node_execution()
        results.append(getattr(obj, func)())
    else:
        for i in range(max_len_input):
            if allow_interrupt:
                nodes.before_node_execution()
            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
    return results
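
# Rough behavioral sketch (hypothetical node, not part of this module): with
# input_data_all = {"a": [1, 2, 3], "b": [10]} and INPUT_IS_LIST absent or
# falsy, the node function runs three times -- f(a=1, b=10), f(a=2, b=10),
# f(a=3, b=10) -- the shorter list being padded by repeating its last element
# via slice_dict(). With INPUT_IS_LIST set, it runs once on the full lists.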

def get_output_data(obj, input_data_all):
    
    results = []
    uis = []
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

    for r in return_values:
        if isinstance(r, dict):
            if 'ui' in r:
                uis.append(r['ui'])
            if 'result' in r:
                results.append(r['result'])
        else:
            results.append(r)
    
    output = []
    if len(results) > 0:
        # check which outputs need concatenating
        output_is_list = [False] * len(results[0])
        if hasattr(obj, "OUTPUT_IS_LIST"):
            output_is_list = obj.OUTPUT_IS_LIST

        # merge node execution results
        for i, is_list in zip(range(len(results[0])), output_is_list):
            if is_list:
                output.append([x for o in results for x in o[i]])
            else:
                output.append([o[i] for o in results])

    ui = dict()    
    if len(uis) > 0:
        ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
    return output, ui
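
# Example of the merge above (hypothetical node executed twice): with
# OUTPUT_IS_LIST = (False, True) and per-run return values ("x1", [1, 2])
# and ("x2", [3]), output becomes [["x1", "x2"], [1, 2, 3]] -- non-list
# slots are gathered into a list, list slots are concatenated.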

def format_value(x):
    if x is None:
        return None
    elif isinstance(x, (int, float, bool, str)):
        return x
    else:
        return str(x)

def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if unique_id in outputs:
        return (True, None, None)

    for x in inputs:
        input_data = inputs[x]

        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
                if result[0] is not True:
                    # Another node failed further upstream
                    return result

    input_data_all = None
    try:
        input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
        if server.client_id is not None:
            server.last_node_id = unique_id
            server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)

        obj = object_storage.get((unique_id, class_type), None)
        if obj is None:
            obj = class_def()
            object_storage[(unique_id, class_type)] = obj

        output_data, output_ui = get_output_data(obj, input_data_all)
        outputs[unique_id] = output_data
        if len(output_ui) > 0:
            outputs_ui[unique_id] = output_ui
            if server.client_id is not None:
                server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
    except comfy.model_management.InterruptProcessingException as iex:
        logging.info("Processing interrupted")

        # skip formatting inputs/outputs
        error_details = {
            "node_id": unique_id,
        }

        return (False, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)
        input_data_formatted = {}
        if input_data_all is not None:
            for name, input_list in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in input_list]

        output_data_formatted = {}
        for node_id, node_outputs in outputs.items():
            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]

        logging.error(f"!!! Exception during processing!!! {ex}")
        logging.error(traceback.format_exc())

        error_details = {
            "node_id": unique_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
            "current_outputs": output_data_formatted
        }
        return (False, error_details, ex)

    executed.add(unique_id)

    return (True, None, None)
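
# recursive_execute returns a (success, error_details, exception) triple:
# (True, None, None) on success or a cache hit, and (False, {...}, ex) when
# this node or any upstream dependency failed, so the caller can report the
# offending node rather than a bare traceback.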

def recursive_will_execute(prompt, outputs, current_item, memo=None):
    # A mutable default ({}) would be shared across top-level calls, so the
    # memo is created fresh when not supplied.
    if memo is None:
        memo = {}
    unique_id = current_item

    if unique_id in memo:
        return memo[unique_id]

    inputs = prompt[unique_id]['inputs']
    will_execute = []
    if unique_id in outputs:
        return []

    for x in inputs:
        input_data = inputs[x]
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)

    memo[unique_id] = will_execute + [unique_id]
    return memo[unique_id]
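
# Example (hypothetical node ids): for an uncached chain "1" -> "2" -> "3",
# recursive_will_execute(prompt, {}, "3") returns ["1", "2", "3"] -- every
# uncached transitive dependency plus the node itself. execute() below only
# uses the length of this list to pick the cheapest output to run first.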

def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

    is_changed_old = ''
    is_changed = ''
    to_delete = False
    if hasattr(class_def, 'IS_CHANGED'):
        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
            is_changed_old = old_prompt[unique_id]['is_changed']
        if 'is_changed' not in prompt[unique_id]:
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
            if input_data_all is not None:
                try:
                    #is_changed = class_def.IS_CHANGED(**input_data_all)
                    is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                    prompt[unique_id]['is_changed'] = is_changed
                except Exception:
                    # If IS_CHANGED itself raises, treat the node as changed.
                    to_delete = True
        else:
            is_changed = prompt[unique_id]['is_changed']

    if unique_id not in outputs:
        return True

    if not to_delete:
        if is_changed != is_changed_old:
            to_delete = True
        elif unique_id not in old_prompt:
            to_delete = True
        elif class_type != old_prompt[unique_id]['class_type']:
            to_delete = True
        elif inputs == old_prompt[unique_id]['inputs']:
            for x in inputs:
                input_data = inputs[x]

                if isinstance(input_data, list):
                    input_unique_id = input_data[0]
                    output_index = input_data[1]
                    if input_unique_id in outputs:
                        to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
                    else:
                        to_delete = True
                    if to_delete:
                        break
        else:
            to_delete = True

    if to_delete:
        d = outputs.pop(unique_id)
        del d
    return to_delete
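
# Cache-invalidation summary: a cached output is dropped when the node's
# IS_CHANGED value differs from the previous run, when the node is new or
# its class_type or inputs changed, or when any upstream node it links to
# was itself invalidated.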

class PromptExecutor:
    def __init__(self, server):
        self.server = server
        self.reset()

    def reset(self):
        self.outputs = {}
        self.object_storage = {}
        self.outputs_ui = {}
        self.status_messages = []
        self.success = True
        self.old_prompt = {}

    def add_message(self, event, data: dict, broadcast: bool):
        data = {
            **data,
            "timestamp": int(time.time() * 1000),
        }
        self.status_messages.append((event, data))
        if self.server.client_id is not None or broadcast:
            self.server.send_sync(event, data, self.server.client_id)

    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
        node_id = error["node_id"]
        class_type = prompt[node_id]["class_type"]

        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, comfy.model_management.InterruptProcessingException):
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
            }
            self.add_message("execution_interrupted", mes, broadcast=True)
        else:
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),

                "exception_message": error["exception_message"],
                "exception_type": error["exception_type"],
                "traceback": error["traceback"],
                "current_inputs": error["current_inputs"],
                "current_outputs": error["current_outputs"],
            }
            self.add_message("execution_error", mes, broadcast=False)
        
        # Next, remove the subsequent outputs since they will not be executed
        to_delete = []
        for o in self.outputs:
            if (o not in current_outputs) and (o not in executed):
                to_delete += [o]
                if o in self.old_prompt:
                    d = self.old_prompt.pop(o)
                    del d
        for o in to_delete:
            d = self.outputs.pop(o)
            del d

    def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
        nodes.interrupt_processing(False)

        if "client_id" in extra_data:
            self.server.client_id = extra_data["client_id"]
        else:
            self.server.client_id = None

        self.status_messages = []
        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)

        with torch.inference_mode():
            #delete cached outputs if nodes don't exist for them
            to_delete = []
            for o in self.outputs:
                if o not in prompt:
                    to_delete += [o]
            for o in to_delete:
                d = self.outputs.pop(o)
                del d
            to_delete = []
            for o in self.object_storage:
                if o[0] not in prompt:
                    to_delete += [o]
                else:
                    p = prompt[o[0]]
                    if o[1] != p['class_type']:
                        to_delete += [o]
            for o in to_delete:
                d = self.object_storage.pop(o)
                del d

            for x in prompt:
                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

            current_outputs = set(self.outputs.keys())
            for x in list(self.outputs_ui.keys()):
                if x not in current_outputs:
                    d = self.outputs_ui.pop(x)
                    del d

            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
            self.add_message("execution_cached",
                          { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                          broadcast=False)
            executed = set()
            output_node_id = None
            to_execute = []

            for node_id in list(execute_outputs):
                to_execute += [(0, node_id)]

            while len(to_execute) > 0:
                #always execute the output that depends on the least amount of unexecuted nodes first
                memo = {}
                to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
                output_node_id = to_execute.pop(0)[-1]

                # This call shouldn't raise anything if there's an error deep in
                # the actual SD code, instead it will report the node where the
                # error was raised
                self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
                if self.success is not True:
                    self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                    break
            else:
                # Only execute when the while-loop ends without break
                self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

            for x in executed:
                self.old_prompt[x] = copy.deepcopy(prompt[x])
            self.server.last_node_id = None
            if comfy.model_management.DISABLE_SMART_MEMORY:
                comfy.model_management.unload_all_models()
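
# Minimal usage sketch (the server object and prompt shape are assumptions
# based on how they are used above, not a documented API): `prompt` maps
# node ids to {"class_type": ..., "inputs": ...} and `execute_outputs`
# lists the output node ids to run.
#
#   executor = PromptExecutor(server)
#   executor.execute(prompt, prompt_id, extra_data, execute_outputs)
#   if executor.success:
#       print(executor.outputs_ui)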


def validate_inputs(prompt, item, validated):
    unique_id = item
    if unique_id in validated:
        return validated[unique_id]

    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    class_inputs = obj_class.INPUT_TYPES()
    required_inputs = class_inputs['required']

    errors = []
    valid = True

    validate_function_inputs = []
    if hasattr(obj_class, "VALIDATE_INPUTS"):
        validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args

    for x in required_inputs:
        if x not in inputs:
            error = {
                "type": "required_input_missing",
                "message": "Required input is missing",
                "details": f"{x}",
                "extra_info": {
                    "input_name": x
                }
            }
            errors.append(error)
            continue

        val = inputs[x]
        info = required_inputs[x]
        type_input = info[0]
        if isinstance(val, list):
            if len(val) != 2:
                error = {
                    "type": "bad_linked_input",
                    "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val
                    }
                }
                errors.append(error)
                continue

            o_id = val[0]
            o_class_type = prompt[o_id]['class_type']
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            if r[val[1]] != type_input:
                received_type = r[val[1]]
                details = f"{x}, {received_type} != {type_input}"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
                    "details": details,
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_type": received_type,
                        "linked_node": val
                    }
                }
                errors.append(error)
                continue
            try:
                r = validate_inputs(prompt, o_id, validated)
                if r[0] is False:
                    # `r` will be set in `validated[o_id]` already
                    valid = False
                    continue
            except Exception as ex:
                typ, _, tb = sys.exc_info()
                valid = False
                exception_type = full_type_name(typ)
                reasons = [{
                    "type": "exception_during_inner_validation",
                    "message": "Exception when validating inner node",
                    "details": str(ex),
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "exception_message": str(ex),
                        "exception_type": exception_type,
                        "traceback": traceback.format_tb(tb),
                        "linked_node": val
                    }
                }]
                validated[o_id] = (False, reasons, o_id)
                continue
        else:
            try:
                if type_input == "INT":
                    val = int(val)
                    inputs[x] = val
                if type_input == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if type_input == "STRING":
                    val = str(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {type_input} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                        "exception_message": str(ex)
                    }
                }
                errors.append(error)
                continue

            if len(info) > 1:
                if "min" in info[1] and val < info[1]["min"]:
                    error = {
                        "type": "value_smaller_than_min",
                        "message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue
                if "max" in info[1] and val > info[1]["max"]:
                    error = {
                        "type": "value_bigger_than_max",
                        "message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue

            if x not in validate_function_inputs:
                if isinstance(type_input, list):
                    if val not in type_input:
                        input_config = info
                        list_info = ""

                        # Don't send back gigantic lists like if they're lots of
                        # scanned model filepaths
                        if len(type_input) > 20:
                            list_info = f"(list of length {len(type_input)})"
                            input_config = None
                        else:
                            list_info = str(type_input)

                        error = {
                            "type": "value_not_in_list",
                            "message": "Value not in list",
                            "details": f"{x}: '{val}' not in {list_info}",
                            "extra_info": {
                                "input_name": x,
                                "input_config": input_config,
                                "received_value": val,
                            }
                        }
                        errors.append(error)
                        continue

    if len(validate_function_inputs) > 0:
        input_data_all = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {}
        for x in input_data_all:
            if x in validate_function_inputs:
                input_filtered[x] = input_data_all[x]

        #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
        ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
        for x in input_filtered:
            for i, r in enumerate(ret):
                if r is not True:
                    details = f"{x}"
                    if r is not False:
                        details += f" - {str(r)}"

                    error = {
                        "type": "custom_validation_failed",
                        "message": "Custom validation failed for node",
                        "details": details,
                        "extra_info": {
                            "input_name": x,
                            # report this input's own config and value rather
                            # than whatever the earlier loop left in info/val
                            "input_config": required_inputs.get(x),
                            "received_value": inputs.get(x),
                        }
                    }
                    errors.append(error)

    if len(errors) > 0 or valid is not True:
        ret = (False, errors, unique_id)
    else:
        ret = (True, [], unique_id)

    validated[unique_id] = ret
    return ret
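
# validate_inputs returns a (valid, errors, node_id) triple and memoizes it
# in `validated`, so a node shared by several outputs is only checked once
# per prompt.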

def full_type_name(klass):
    module = klass.__module__
    if module == 'builtins':
        return klass.__qualname__
    return module + '.' + klass.__qualname__
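
# e.g. full_type_name(ValueError) -> "ValueError", while a non-builtin class
# yields its dotted path, such as
# "comfy.model_management.InterruptProcessingException".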

def validate_prompt(prompt):
    outputs = set()
    for x in prompt:
        if 'class_type' not in prompt[x]:
            error = {
                "type": "invalid_prompt",
                "message": f"Cannot execute because a node is missing the class_type property.",
                "details": f"Node ID '#{x}'",
                "extra_info": {}
            }
            return (False, error, [], [])

        class_type = prompt[x]['class_type']
        class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
        if class_ is None:
            error = {
                "type": "invalid_prompt",
                "message": f"Cannot execute because node {class_type} does not exist.",
                "details": f"Node ID '#{x}'",
                "extra_info": {}
            }
            return (False, error, [], [])

        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
            outputs.add(x)

    if len(outputs) == 0:
        error = {
            "type": "prompt_no_outputs",
            "message": "Prompt has no outputs",
            "details": "",
            "extra_info": {}
        }
        return (False, error, [], [])

    good_outputs = set()
    errors = []
    node_errors = {}
    validated = {}
    for o in outputs:
        valid = False
        reasons = []
        try:
            m = validate_inputs(prompt, o, validated)
            valid = m[0]
            reasons = m[1]
        except Exception as ex:
            typ, _, tb = sys.exc_info()
            valid = False
            exception_type = full_type_name(typ)
            reasons = [{
                "type": "exception_during_validation",
                "message": "Exception when validating node",
                "details": str(ex),
                "extra_info": {
                    "exception_type": exception_type,
                    "traceback": traceback.format_tb(tb)
                }
            }]
            validated[o] = (False, reasons, o)

        if valid is True:
            good_outputs.add(o)
        else:
            logging.error(f"Failed to validate prompt for output {o}:")
            if len(reasons) > 0:
                logging.error("* (prompt):")
                for reason in reasons:
                    logging.error(f"  - {reason['message']}: {reason['details']}")
            errors += [(o, reasons)]
            for node_id, result in validated.items():
                valid = result[0]
                reasons = result[1]
                # If a node upstream has errors, the nodes downstream will also
                # be reported as invalid, but there will be no errors attached.
                # So don't return those nodes as having errors in the response.
                if valid is not True and len(reasons) > 0:
                    if node_id not in node_errors:
                        class_type = prompt[node_id]['class_type']
                        node_errors[node_id] = {
                            "errors": reasons,
                            "dependent_outputs": [],
                            "class_type": class_type
                        }
                        logging.error(f"* {class_type} {node_id}:")
                        for reason in reasons:
                            logging.error(f"  - {reason['message']}: {reason['details']}")
                    node_errors[node_id]["dependent_outputs"].append(o)
            logging.error("Output will be ignored")

    if len(good_outputs) == 0:
        errors_list = []
        for o, output_errors in errors:
            for error in output_errors:
                errors_list.append(f"{error['message']}: {error['details']}")
        errors_list = "\n".join(errors_list)

        error = {
            "type": "prompt_outputs_failed_validation",
            "message": "Prompt outputs failed validation",
            "details": errors_list,
            "extra_info": {}
        }

        return (False, error, list(good_outputs), node_errors)

    return (True, None, list(good_outputs), node_errors)
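
# validate_prompt returns (valid, error, good_output_ids, node_errors):
# `good_output_ids` lists the output nodes that passed validation and
# `node_errors` maps node_id -> error details for the frontend to display.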

MAXIMUM_HISTORY_SIZE = 10000

class PromptQueue:
    def __init__(self, server):
        self.server = server
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        self.task_counter = 0
        self.queue = []
        self.currently_running = {}
        self.history = {}
        self.flags = {}
        server.prompt_queue = self

    def put(self, item):
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self, timeout=None):
        with self.not_empty:
            while len(self.queue) == 0:
                self.not_empty.wait(timeout=timeout)
                if timeout is not None and len(self.queue) == 0:
                    return None
            item = heapq.heappop(self.queue)
            i = self.task_counter
            self.currently_running[i] = copy.deepcopy(item)
            self.task_counter += 1
            self.server.queue_updated()
            return (item, i)

    class ExecutionStatus(NamedTuple):
        status_str: Literal['success', 'error']
        completed: bool
        messages: List[str]

    def task_done(self, item_id, outputs,
                  status: Optional['PromptQueue.ExecutionStatus']):
        with self.mutex:
            prompt = self.currently_running.pop(item_id)
            if len(self.history) > MAXIMUM_HISTORY_SIZE:
                self.history.pop(next(iter(self.history)))

            status_dict: Optional[dict] = None
            if status is not None:
                status_dict = copy.deepcopy(status._asdict())

            self.history[prompt[1]] = {
                "prompt": prompt,
                "outputs": copy.deepcopy(outputs),
                'status': status_dict,
            }
            self.server.queue_updated()

    def get_current_queue(self):
        with self.mutex:
            out = []
            for x in self.currently_running.values():
                out += [x]
            return (out, copy.deepcopy(self.queue))

    def get_tasks_remaining(self):
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        with self.mutex:
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
        with self.mutex:
            for x in range(len(self.queue)):
                if function(self.queue[x]):
                    if len(self.queue) == 1:
                        self.wipe_queue()
                    else:
                        self.queue.pop(x)
                        heapq.heapify(self.queue)
                    self.server.queue_updated()
                    return True
        return False

    def get_history(self, prompt_id=None, max_items=None, offset=-1):
        with self.mutex:
            if prompt_id is None:
                out = {}
                i = 0
                if offset < 0 and max_items is not None:
                    offset = len(self.history) - max_items
                for k in self.history:
                    if i >= offset:
                        out[k] = self.history[k]
                        if max_items is not None and len(out) >= max_items:
                            break
                    i += 1
                return out
            elif prompt_id in self.history:
                return {prompt_id: copy.deepcopy(self.history[prompt_id])}
            else:
                return {}

    def wipe_history(self):
        with self.mutex:
            self.history = {}

    def delete_history_item(self, id_to_delete):
        with self.mutex:
            self.history.pop(id_to_delete, None)

    def set_flag(self, name, data):
        with self.mutex:
            self.flags[name] = data
            self.not_empty.notify()

    def get_flags(self, reset=True):
        with self.mutex:
            if reset:
                ret = self.flags
                self.flags = {}
                return ret
            else:
                return self.flags.copy()
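
# Minimal usage sketch (single consumer; the queue item shape is an
# assumption based on how prompt[1] is used as the prompt id above):
#
#   q = PromptQueue(server)
#   q.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
#   item, item_id = q.get()
#   # ... execute the prompt ...
#   q.task_done(item_id, outputs, status=PromptQueue.ExecutionStatus(
#       status_str='success', completed=True, messages=[]))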