"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "9ee0feae3249644327b0f350c4af04319a2f5764"
execution.py 29.2 KB
Newer Older
1
2
3
4
import os
import sys
import copy
import json
5
import logging
6
7
8
import threading
import heapq
import traceback
9
import gc
10
import inspect
11
12
13
14

import torch
import nodes

15
import comfy.model_management
16

17
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
    """Build the keyword-argument lists for one node from its prompt inputs.

    Args:
        inputs: Mapping of input name to either a literal value or a link,
            i.e. a list ``[source_node_id, output_index]``.
        class_def: Node class; its ``INPUT_TYPES()`` result is consulted for
            the "required", "optional" and "hidden" sections.
        unique_id: Id of the node the inputs belong to.
        outputs: Cached node outputs keyed by node id. Defaults to empty.
        prompt: The full prompt, injected for hidden "PROMPT" inputs.
        extra_data: Extra request data; only "extra_pnginfo" is read here.

    Returns:
        Dict mapping input name to a list of values. A link whose source node
        has no cached output yields ``(None,)``.
    """
    # Fresh containers per call: the previous shared ``{}`` defaults were
    # handed to nodes (``[prompt]``) and could be mutated across calls.
    outputs = {} if outputs is None else outputs
    prompt = {} if prompt is None else prompt
    extra_data = {} if extra_data is None else extra_data

    valid_inputs = class_def.INPUT_TYPES()
    input_data_all = {}
    for x, input_data in inputs.items():
        if isinstance(input_data, list):
            # Linked input: [source node id, index into that node's outputs].
            input_unique_id = input_data[0]
            output_index = input_data[1]
            if input_unique_id not in outputs:
                input_data_all[x] = (None,)
                continue
            input_data_all[x] = outputs[input_unique_id][output_index]
        else:
            # Literal values are only forwarded when the node declares them.
            declared = (
                ("required" in valid_inputs and x in valid_inputs["required"])
                or ("optional" in valid_inputs and x in valid_inputs["optional"])
            )
            if declared:
                input_data_all[x] = [input_data]

    if "hidden" in valid_inputs:
        h = valid_inputs["hidden"]
        for x in h:
            if h[x] == "PROMPT":
                input_data_all[x] = [prompt]
            if h[x] == "EXTRA_PNGINFO":
                if "extra_pnginfo" in extra_data:
                    input_data_all[x] = [extra_data['extra_pnginfo']]
            if h[x] == "UNIQUE_ID":
                input_data_all[x] = [unique_id]
    return input_data_all

46
47
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
    """Call ``getattr(obj, func)`` over the per-input value lists.

    Args:
        obj: Node instance or class that owns the method named ``func``.
        input_data_all: Dict mapping input name to a list of values.
        func: Name of the attribute on ``obj`` to call.
        allow_interrupt: When true, ``nodes.before_node_execution()`` is
            called before every invocation.

    Returns:
        List with one entry per invocation. If ``obj.INPUT_IS_LIST`` is
        truthy the method is called once with the whole lists; if there are
        no values it is called once with no arguments; otherwise it is called
        once per index of the longest list.
    """
    # check if node wants the lists
    input_is_list = getattr(obj, "INPUT_IS_LIST", False)

    max_len_input = max((len(values) for values in input_data_all.values()), default=0)

    # get a slice of inputs, repeat last input when list isn't long enough
    def slice_dict(d, i):
        return {k: v[i if len(v) > i else -1] for k, v in d.items()}

    def invoke(kwargs):
        if allow_interrupt:
            nodes.before_node_execution()
        return getattr(obj, func)(**kwargs)

    if input_is_list:
        return [invoke(input_data_all)]
    if max_len_input == 0:
        return [invoke({})]
    return [invoke(slice_dict(input_data_all, i)) for i in range(max_len_input)]

def get_output_data(obj, input_data_all):
    """Run the node's FUNCTION over its inputs and merge the results.

    Each return value is either a plain result or a dict that may carry a
    'ui' entry and/or a 'result' entry. Returns ``(output, ui)``: ``output``
    has one list per output slot, ``ui`` merges the collected 'ui' dicts.
    """
    node_results = []
    ui_parts = []
    for returned in map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True):
        if not isinstance(returned, dict):
            node_results.append(returned)
            continue
        if 'ui' in returned:
            ui_parts.append(returned['ui'])
        if 'result' in returned:
            node_results.append(returned['result'])

    merged_outputs = []
    if node_results:
        slot_count = len(node_results[0])
        # Slots flagged in OUTPUT_IS_LIST are flattened across invocations;
        # every other slot collects one value per invocation.
        if hasattr(obj, "OUTPUT_IS_LIST"):
            list_flags = obj.OUTPUT_IS_LIST
        else:
            list_flags = [False] * slot_count

        for slot, flatten in zip(range(slot_count), list_flags):
            if flatten:
                merged_outputs.append(
                    [item for call_result in node_results for item in call_result[slot]]
                )
            else:
                merged_outputs.append([call_result[slot] for call_result in node_results])

    merged_ui = {}
    if ui_parts:
        # Keys come from the first ui dict; values are concatenated in order.
        for key in ui_parts[0].keys():
            merged_ui[key] = [entry for part in ui_parts for entry in part[key]]
    return merged_outputs, merged_ui

114
def format_value(x):
    """Return ``x`` unchanged if it is None, int, float, bool or str; otherwise ``str(x)``."""
    if x is None or isinstance(x, (int, float, bool, str)):
        return x
    return str(x)

122
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
    """Execute ``current_item`` after executing any upstream nodes it links to.

    Args:
        server: Object exposing ``client_id``, ``last_node_id`` and ``send_sync``.
        prompt: Mapping of node id to ``{'inputs': ..., 'class_type': ...}``.
        outputs: Cache of node outputs keyed by node id; updated in place.
        current_item: Id of the node to execute.
        extra_data: Extra request data forwarded to ``get_input_data``.
        executed: Set that receives the id of every node run successfully.
        prompt_id: Id of the prompt, included in messages sent to the client.
        outputs_ui: Cache of UI outputs keyed by node id; updated in place.
        object_storage: Cache of node instances keyed by ``(node_id, class_type)``.

    Returns:
        ``(True, None, None)`` on success (or when the node is already cached),
        otherwise ``(False, error_details, exception)``.
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if unique_id in outputs:
        return (True, None, None)

    # Make sure every linked upstream node has produced its outputs first.
    for input_data in inputs.values():
        if isinstance(input_data, list):
            input_unique_id = input_data[0]
            if input_unique_id not in outputs:
                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
                if result[0] is not True:
                    # Another node failed further upstream
                    return result

    input_data_all = None
    try:
        input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
        if server.client_id is not None:
            server.last_node_id = unique_id
            server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)

        # Reuse the node instance across runs when the id and class match.
        obj = object_storage.get((unique_id, class_type), None)
        if obj is None:
            obj = class_def()
            object_storage[(unique_id, class_type)] = obj

        output_data, output_ui = get_output_data(obj, input_data_all)
        outputs[unique_id] = output_data
        if len(output_ui) > 0:
            outputs_ui[unique_id] = output_ui
            if server.client_id is not None:
                server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
    except comfy.model_management.InterruptProcessingException as iex:
        logging.info("Processing interrupted")

        # skip formatting inputs/outputs
        error_details = {
            "node_id": unique_id,
        }

        return (False, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)

        # Distinct loop names so the node's `inputs` mapping is not rebound.
        input_data_formatted = {}
        if input_data_all is not None:
            for name, values in input_data_all.items():
                input_data_formatted[name] = [format_value(v) for v in values]

        output_data_formatted = {}
        for node_id, node_outputs in outputs.items():
            output_data_formatted[node_id] = [[format_value(v) for v in slot] for slot in node_outputs]

        logging.error("!!! Exception during processing !!!")
        logging.error(traceback.format_exc())

        error_details = {
            "node_id": unique_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
            "current_outputs": output_data_formatted
        }
        return (False, error_details, ex)

    executed.add(unique_id)

    return (True, None, None)

199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def recursive_will_execute(prompt, outputs, current_item):
    """List the node ids that would run to produce ``current_item``.

    Upstream nodes without a cached entry in ``outputs`` come first (a node
    reachable through several links is listed once per link), followed by
    ``current_item`` itself. Returns ``[]`` when ``current_item`` is cached.
    """
    if current_item in outputs:
        # Still index the prompt so an unknown node id fails the same way.
        prompt[current_item]['inputs']
        return []

    node_inputs = prompt[current_item]['inputs']
    pending = []
    for link in node_inputs.values():
        if not isinstance(link, list):
            continue
        upstream_id, _slot = link[0], link[1]
        if upstream_id not in outputs:
            pending.extend(recursive_will_execute(prompt, outputs, upstream_id))

    pending.append(current_item)
    return pending

def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
    """Drop the cached output of ``current_item`` if it can no longer be reused.

    The cached output is removed when the node's IS_CHANGED result differs
    from the one stored in ``old_prompt``, when IS_CHANGED raised, when the
    node is not in ``old_prompt``, when its inputs differ from the old ones,
    or when any linked upstream node is missing from ``outputs`` or was
    itself invalidated.

    Args:
        prompt: Current prompt; ``prompt[id]['is_changed']`` is filled in here.
        old_prompt: Prompt entries recorded from the previous execution.
        outputs: Cache of node outputs keyed by node id; updated in place.
        current_item: Id of the node to check.

    Returns:
        True when the node has no cached output or it was deleted, else False.
    """
    unique_id = current_item
    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]

    is_changed_old = ''
    is_changed = ''
    to_delete = False
    if hasattr(class_def, 'IS_CHANGED'):
        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
            is_changed_old = old_prompt[unique_id]['is_changed']
        if 'is_changed' not in prompt[unique_id]:
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
            try:
                is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                prompt[unique_id]['is_changed'] = is_changed
            except Exception as ex:
                # A failing IS_CHANGED invalidates the cache. Only Exception is
                # caught so KeyboardInterrupt/SystemExit still propagate.
                logging.warning("IS_CHANGED failed for node %s: %s", unique_id, ex)
                to_delete = True
        else:
            is_changed = prompt[unique_id]['is_changed']

    if unique_id not in outputs:
        return True

    if not to_delete:
        if is_changed != is_changed_old:
            to_delete = True
        elif unique_id not in old_prompt:
            to_delete = True
        elif inputs == old_prompt[unique_id]['inputs']:
            # Same inputs as before: reusable only if every upstream is too.
            for input_data in inputs.values():
                if isinstance(input_data, list):
                    input_unique_id = input_data[0]
                    if input_unique_id in outputs:
                        to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
                    else:
                        to_delete = True
                    if to_delete:
                        break
        else:
            to_delete = True

    if to_delete:
        outputs.pop(unique_id)
    return to_delete

class PromptExecutor:
    """Runs prompts node by node while caching outputs between runs."""

    def __init__(self, server):
        """Store ``server`` and start with empty caches."""
        self.server = server
        self.reset()

    def reset(self):
        """Clear all cached state."""
        # node id -> cached outputs
        self.outputs = {}
        # (node id, class_type) -> node instance
        self.object_storage = {}
        # node id -> cached UI output
        self.outputs_ui = {}
        # node id -> prompt entry recorded when the node last executed
        self.old_prompt = {}

    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
        """Report a failed execution and drop outputs that are now stale.

        Args:
            prompt_id: Id of the prompt being executed.
            prompt: The prompt; used to look up the failing node's class type.
            current_outputs: Node ids that were cached before execution began.
            executed: Node ids executed successfully during this run.
            error: Error details dict; must contain "node_id".
            ex: The exception that stopped execution.
        """
        node_id = error["node_id"]
        class_type = prompt[node_id]["class_type"]

        # First, send back the status to the frontend depending
        # on the exception type
        if isinstance(ex, comfy.model_management.InterruptProcessingException):
            mes = {
                "prompt_id": prompt_id,
                "node_id": node_id,
                "node_type": class_type,
                "executed": list(executed),
            }
            self.server.send_sync("execution_interrupted", mes, self.server.client_id)
        else:
            if self.server.client_id is not None:
                mes = {
                    "prompt_id": prompt_id,
                    "node_id": node_id,
                    "node_type": class_type,
                    "executed": list(executed),

                    "exception_message": error["exception_message"],
                    "exception_type": error["exception_type"],
                    "traceback": error["traceback"],
                    "current_inputs": error["current_inputs"],
                    "current_outputs": error["current_outputs"],
                }
                self.server.send_sync("execution_error", mes, self.server.client_id)

        # Next, remove the subsequent outputs since they will not be executed
        stale = [o for o in self.outputs if o not in current_outputs and o not in executed]
        for o in stale:
            self.old_prompt.pop(o, None)
            self.outputs.pop(o)

    def execute(self, prompt, prompt_id, extra_data=None, execute_outputs=None):
        """Execute the requested output nodes of ``prompt``.

        Args:
            prompt: Mapping of node id to node definition.
            prompt_id: Id included in every message sent to the client.
            extra_data: Optional request data; "client_id" selects the client
                that receives progress messages.
            execute_outputs: Ids of the output nodes to execute.
        """
        # Fresh containers instead of shared mutable defaults.
        extra_data = {} if extra_data is None else extra_data
        execute_outputs = [] if execute_outputs is None else execute_outputs

        nodes.interrupt_processing(False)

        self.server.client_id = extra_data.get("client_id")

        if self.server.client_id is not None:
            self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)

        with torch.inference_mode():
            #delete cached outputs if nodes don't exist for them
            for o in [o for o in self.outputs if o not in prompt]:
                self.outputs.pop(o)

            # Drop node instances whose node is gone or changed class type.
            stale_objects = [
                key for key in self.object_storage
                if key[0] not in prompt or key[1] != prompt[key[0]]['class_type']
            ]
            for key in stale_objects:
                self.object_storage.pop(key)

            for x in prompt:
                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

            current_outputs = set(self.outputs.keys())
            for x in list(self.outputs_ui.keys()):
                if x not in current_outputs:
                    self.outputs_ui.pop(x)

            comfy.model_management.cleanup_models()
            if self.server.client_id is not None:
                self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
            executed = set()
            to_execute = list(execute_outputs)

            while len(to_execute) > 0:
                #always execute the output that depends on the least amount of unexecuted nodes first
                ranked = sorted(
                    (len(recursive_will_execute(prompt, self.outputs, node_id)), node_id)
                    for node_id in to_execute
                )
                output_node_id = ranked[0][1]
                to_execute = [node_id for _, node_id in ranked[1:]]

                # This call shouldn't raise anything if there's an error deep in
                # the actual SD code, instead it will report the node where the
                # error was raised
                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
                if success is not True:
                    self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                    break

            # Remember what ran so the next run can detect changes.
            for x in executed:
                self.old_prompt[x] = copy.deepcopy(prompt[x])
            self.server.last_node_id = None
            if comfy.model_management.DISABLE_SMART_MEMORY:
                comfy.model_management.unload_all_models()
391

392

393

394
def validate_inputs(prompt, item, validated):
    """Validate the inputs of node ``item`` and, recursively, its upstream nodes.

    Literal INT/FLOAT/STRING inputs are converted in place inside ``prompt``.

    Args:
        prompt: Mapping of node id to ``{'inputs': ..., 'class_type': ...}``.
        item: Id of the node to validate.
        validated: Memo of results keyed by node id; updated in place.

    Returns:
        Tuple ``(valid, errors, node_id)``. ``errors`` can be empty for an
        invalid node when the failure lies in an upstream node.
    """
    unique_id = item
    if unique_id in validated:
        return validated[unique_id]

    inputs = prompt[unique_id]['inputs']
    class_type = prompt[unique_id]['class_type']
    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

    class_inputs = obj_class.INPUT_TYPES()
    required_inputs = class_inputs['required']

    errors = []
    valid = True

    # Inputs named as VALIDATE_INPUTS arguments skip the built-in list check.
    validate_function_inputs = []
    if hasattr(obj_class, "VALIDATE_INPUTS"):
        validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args

    for x in required_inputs:
        if x not in inputs:
            error = {
                "type": "required_input_missing",
                "message": "Required input is missing",
                "details": f"{x}",
                "extra_info": {
                    "input_name": x
                }
            }
            errors.append(error)
            continue

        val = inputs[x]
        info = required_inputs[x]
        type_input = info[0]
        if isinstance(val, list):
            # Linked input: [source node id, output slot index].
            if len(val) != 2:
                error = {
                    "type": "bad_linked_input",
                    "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]",
                    "details": f"{x}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val
                    }
                }
                errors.append(error)
                continue

            o_id = val[0]
            o_class_type = prompt[o_id]['class_type']
            r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
            received_type = r[val[1]]
            if received_type != type_input:
                details = f"{x}, {received_type} != {type_input}"
                error = {
                    "type": "return_type_mismatch",
                    "message": "Return type mismatch between linked nodes",
                    "details": details,
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_type": received_type,
                        "linked_node": val
                    }
                }
                errors.append(error)
                continue
            try:
                r = validate_inputs(prompt, o_id, validated)
                if r[0] is False:
                    # `r` will be set in `validated[o_id]` already
                    valid = False
                    continue
            except Exception as ex:
                typ, _, tb = sys.exc_info()
                valid = False
                exception_type = full_type_name(typ)
                reasons = [{
                    "type": "exception_during_inner_validation",
                    "message": "Exception when validating inner node",
                    "details": str(ex),
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "exception_message": str(ex),
                        "exception_type": exception_type,
                        "traceback": traceback.format_tb(tb),
                        "linked_node": val
                    }
                }]
                validated[o_id] = (False, reasons, o_id)
                continue
        else:
            # Literal input: coerce to the declared primitive type and write
            # the converted value back into the prompt.
            try:
                if type_input == "INT":
                    val = int(val)
                    inputs[x] = val
                if type_input == "FLOAT":
                    val = float(val)
                    inputs[x] = val
                if type_input == "STRING":
                    val = str(val)
                    inputs[x] = val
            except Exception as ex:
                error = {
                    "type": "invalid_input_type",
                    "message": f"Failed to convert an input value to a {type_input} value",
                    "details": f"{x}, {val}, {ex}",
                    "extra_info": {
                        "input_name": x,
                        "input_config": info,
                        "received_value": val,
                        "exception_message": str(ex)
                    }
                }
                errors.append(error)
                continue

            if len(info) > 1:
                if "min" in info[1] and val < info[1]["min"]:
                    error = {
                        "type": "value_smaller_than_min",
                        "message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue
                if "max" in info[1] and val > info[1]["max"]:
                    error = {
                        "type": "value_bigger_than_max",
                        "message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
                        "details": f"{x}",
                        "extra_info": {
                            "input_name": x,
                            "input_config": info,
                            "received_value": val,
                        }
                    }
                    errors.append(error)
                    continue

            if x not in validate_function_inputs:
                if isinstance(type_input, list):
                    if val not in type_input:
                        input_config = info
                        list_info = ""

                        # Don't send back gigantic lists like if they're lots of
                        # scanned model filepaths
                        if len(type_input) > 20:
                            list_info = f"(list of length {len(type_input)})"
                            input_config = None
                        else:
                            list_info = str(type_input)

                        error = {
                            "type": "value_not_in_list",
                            "message": "Value not in list",
                            "details": f"{x}: '{val}' not in {list_info}",
                            "extra_info": {
                                "input_name": x,
                                "input_config": input_config,
                                "received_value": val,
                            }
                        }
                        errors.append(error)
                        continue

    if len(validate_function_inputs) > 0:
        input_data_all = get_input_data(inputs, obj_class, unique_id)
        input_filtered = {
            name: data for name, data in input_data_all.items()
            if name in validate_function_inputs
        }

        ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
        optional_inputs = class_inputs.get("optional", {})
        for x in input_filtered:
            # Look up the config and value of *this* input; the loop variables
            # left over from the required-inputs loop describe another input.
            input_config = required_inputs.get(x, optional_inputs.get(x))
            received_value = inputs.get(x)
            for r in ret:
                if r is not True:
                    details = f"{x}"
                    if r is not False:
                        details += f" - {str(r)}"

                    error = {
                        "type": "custom_validation_failed",
                        "message": "Custom validation failed for node",
                        "details": details,
                        "extra_info": {
                            "input_name": x,
                            "input_config": input_config,
                            "received_value": received_value,
                        }
                    }
                    errors.append(error)

    if len(errors) > 0 or valid is not True:
        ret = (False, errors, unique_id)
    else:
        ret = (True, [], unique_id)

    validated[unique_id] = ret
    return ret
605

606
607
608
609
610
611
def full_type_name(klass):
    """Return the qualified name of ``klass``, prefixed by its module unless that is 'builtins'."""
    owner = klass.__module__
    qualified = klass.__qualname__
    return qualified if owner == 'builtins' else '.'.join((owner, qualified))

612
613
614
615
616
617
618
619
def validate_prompt(prompt):
    """Validate every output node of ``prompt``.

    Args:
        prompt: Mapping of node id to ``{'inputs': ..., 'class_type': ...}``.

    Returns:
        Tuple ``(ok, error, good_outputs, node_errors)``. ``ok`` is True when
        at least one output node validated; ``error`` is None in that case and
        a summary dict otherwise. ``good_outputs`` lists the valid output node
        ids. ``node_errors`` maps node id to its errors, class type and the
        outputs depending on it (an empty list when there are no outputs).
    """
    outputs = set()
    for x in prompt:
        class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
        if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE == True:
            outputs.add(x)

    if len(outputs) == 0:
        error = {
            "type": "prompt_no_outputs",
            "message": "Prompt has no outputs",
            "details": "",
            "extra_info": {}
        }
        return (False, error, [], [])

    good_outputs = set()
    errors = []
    node_errors = {}
    validated = {}
    for o in outputs:
        valid = False
        reasons = []
        try:
            m = validate_inputs(prompt, o, validated)
            valid = m[0]
            reasons = m[1]
        except Exception as ex:
            typ, _, tb = sys.exc_info()
            valid = False
            exception_type = full_type_name(typ)
            reasons = [{
                "type": "exception_during_validation",
                "message": "Exception when validating node",
                "details": str(ex),
                "extra_info": {
                    "exception_type": exception_type,
                    "traceback": traceback.format_tb(tb)
                }
            }]
            validated[o] = (False, reasons, o)

        if valid is True:
            good_outputs.add(o)
        else:
            logging.error(f"Failed to validate prompt for output {o}:")
            if len(reasons) > 0:
                logging.error("* (prompt):")
                for reason in reasons:
                    logging.error(f"  - {reason['message']}: {reason['details']}")
            errors += [(o, reasons)]
            # Separate names here so the per-output `valid`/`reasons` above
            # are not overwritten while walking the memo.
            for node_id, result in validated.items():
                node_valid = result[0]
                node_reasons = result[1]
                # If a node upstream has errors, the nodes downstream will also
                # be reported as invalid, but there will be no errors attached.
                # So don't return those nodes as having errors in the response.
                if node_valid is not True and len(node_reasons) > 0:
                    if node_id not in node_errors:
                        class_type = prompt[node_id]['class_type']
                        node_errors[node_id] = {
                            "errors": node_reasons,
                            "dependent_outputs": [],
                            "class_type": class_type
                        }
                        logging.error(f"* {class_type} {node_id}:")
                        for reason in node_reasons:
                            logging.error(f"  - {reason['message']}: {reason['details']}")
                    node_errors[node_id]["dependent_outputs"].append(o)
            logging.error("Output will be ignored")

    if len(good_outputs) == 0:
        # Flatten every output's reasons into a single newline-joined summary.
        errors_list = "\n".join(
            f"{reason['message']}: {reason['details']}"
            for _, output_reasons in errors
            for reason in output_reasons
        )

        error = {
            "type": "prompt_outputs_failed_validation",
            "message": "Prompt outputs failed validation",
            "details": errors_list,
            "extra_info": {}
        }

        return (False, error, list(good_outputs), node_errors)

    return (True, None, list(good_outputs), node_errors)
700

701
# Cap on the number of entries kept in PromptQueue.history; task_done evicts
# the oldest entry to enforce it.
MAXIMUM_HISTORY_SIZE = 10000
702
703
704
705
706
707
708
709
710
711

class PromptQueue:
    """Priority queue of pending prompts plus running tasks, history and flags.

    All state is guarded by one re-entrant lock; ``not_empty`` is a condition
    on that lock, notified when an item is queued or a flag is set.
    """

    def __init__(self, server):
        """Create an empty queue and register it as ``server.prompt_queue``."""
        self.server = server
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        # Monotonic counter used as the id of each task handed out by get().
        self.task_counter = 0
        # Heap of pending items, managed with heapq.
        self.queue = []
        # task id -> deep copy of the item currently being processed
        self.currently_running = {}
        # prompt[1] of a finished item -> {"prompt": item, "outputs": {...}}
        self.history = {}
        self.flags = {}
        server.prompt_queue = self

    def put(self, item):
        """Push ``item`` onto the heap and wake one waiting consumer."""
        with self.mutex:
            heapq.heappush(self.queue, item)
            self.server.queue_updated()
            self.not_empty.notify()

    def get(self, timeout=None):
        """Pop the smallest item, blocking while the queue is empty.

        Returns:
            ``(item, task_id)``, or None when ``timeout`` is given and the
            queue is still empty after a wait.
        """
        with self.not_empty:
            while len(self.queue) == 0:
                self.not_empty.wait(timeout=timeout)
                if timeout is not None and len(self.queue) == 0:
                    return None
            item = heapq.heappop(self.queue)
            i = self.task_counter
            self.currently_running[i] = copy.deepcopy(item)
            self.task_counter += 1
            self.server.queue_updated()
            return (item, i)

    def task_done(self, item_id, outputs):
        """Move task ``item_id`` from running to history with its ``outputs``."""
        with self.mutex:
            prompt = self.currently_running.pop(item_id)
            # Evict the oldest entry *before* inserting so the history never
            # holds more than MAXIMUM_HISTORY_SIZE entries.
            if len(self.history) >= MAXIMUM_HISTORY_SIZE:
                self.history.pop(next(iter(self.history)))
            self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
            for o in outputs:
                self.history[prompt[1]]["outputs"][o] = outputs[o]
            self.server.queue_updated()

    def get_current_queue(self):
        """Return ``(running_items, deep_copy_of_pending_heap)``."""
        with self.mutex:
            out = list(self.currently_running.values())
            return (out, copy.deepcopy(self.queue))

    def get_tasks_remaining(self):
        """Return the number of pending plus running tasks."""
        with self.mutex:
            return len(self.queue) + len(self.currently_running)

    def wipe_queue(self):
        """Discard every pending item."""
        with self.mutex:
            self.queue = []
            self.server.queue_updated()

    def delete_queue_item(self, function):
        """Remove the first pending item for which ``function(item)`` is truthy.

        Returns:
            True if an item was removed, False otherwise.
        """
        with self.mutex:
            for x, queued in enumerate(self.queue):
                if function(queued):
                    if len(self.queue) == 1:
                        self.wipe_queue()
                    else:
                        self.queue.pop(x)
                        # Removing an arbitrary element breaks the heap order.
                        heapq.heapify(self.queue)
                    self.server.queue_updated()
                    return True
        return False

    def get_history(self, prompt_id=None, max_items=None, offset=-1):
        """Return history entries.

        With ``prompt_id`` set, returns a deep copy of that single entry (or
        ``{}`` if unknown). Otherwise returns entries in insertion order,
        starting at ``offset`` and limited to ``max_items``; a negative
        ``offset`` together with ``max_items`` selects the last ``max_items``.
        """
        with self.mutex:
            if prompt_id is None:
                out = {}
                if offset < 0 and max_items is not None:
                    offset = len(self.history) - max_items
                for i, k in enumerate(self.history):
                    if i >= offset:
                        out[k] = self.history[k]
                        if max_items is not None and len(out) >= max_items:
                            break
                return out
            elif prompt_id in self.history:
                return {prompt_id: copy.deepcopy(self.history[prompt_id])}
            else:
                return {}

    def wipe_history(self):
        """Discard the whole history."""
        with self.mutex:
            self.history = {}

    def delete_history_item(self, id_to_delete):
        """Remove one history entry; unknown ids are ignored."""
        with self.mutex:
            self.history.pop(id_to_delete, None)

    def set_flag(self, name, data):
        """Store a flag and notify one thread waiting on ``not_empty``."""
        with self.mutex:
            self.flags[name] = data
            self.not_empty.notify()

    def get_flags(self, reset=True):
        """Return the flags; with ``reset`` they are cleared, otherwise a copy is returned."""
        with self.mutex:
            if reset:
                ret = self.flags
                self.flags = {}
                return ret
            else:
                return self.flags.copy()