main.py 5.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import importlib.util
import folder_paths


def execute_prestartup_script():
    """Run ``prestartup_script.py`` from every enabled custom-node directory.

    Every path returned by ``folder_paths.get_folder_paths("custom_nodes")``
    is listed, and each sub-directory that contains a ``prestartup_script.py``
    has that file executed as a module. Plain files, directories whose name
    ends in ``.disabled`` and ``__pycache__`` directories are skipped. A
    script that raises is reported on stdout and does not stop the others.
    """
    def execute_script(script_path):
        """Execute ``script_path`` as a module if the file exists."""
        if not os.path.exists(script_path):
            return
        # The module is named after the script path minus its extension.
        module_name = os.path.splitext(script_path)[0]
        try:
            spec = importlib.util.spec_from_file_location(module_name, script_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except Exception as e:
            # Best effort: one broken script must not block the remaining ones.
            print(f"Failed to execute startup-script: {script_path} / {e}")

    for custom_node_path in folder_paths.get_folder_paths("custom_nodes"):
        for possible_module in os.listdir(custom_node_path):
            module_path = os.path.join(custom_node_path, possible_module)
            # Compare the entry *name* with "__pycache__": the joined path can
            # never equal that bare name, so the old check never matched.
            if (
                os.path.isfile(module_path)
                or possible_module.endswith(".disabled")
                or possible_module == "__pycache__"
            ):
                continue

            execute_script(os.path.join(module_path, "prestartup_script.py"))


# Executed at import time, ahead of the remaining imports further down, so the
# prestartup scripts run before those modules are loaded.
execute_prestartup_script()


# Main code
EllangoK's avatar
EllangoK committed
34
import asyncio
35
import itertools
36
import shutil
comfyanonymous's avatar
comfyanonymous committed
37
import threading
38
39
import gc
import time
40

41
from comfy.cli_args import args
42
import comfy.utils
comfyanonymous's avatar
comfyanonymous committed
43

pythongosssss's avatar
pythongosssss committed
44
45
46
47
if os.name == "nt":
    import logging

    def _keep_unless_triton_notice(record):
        """Logging filter: reject records that carry the missing-Triton notice."""
        return 'A matching Triton is not available' not in record.getMessage()

    # Only when os.name is "nt": drop that one message from the "xformers" logger.
    logging.getLogger("xformers").addFilter(_keep_unless_triton_notice)

comfyanonymous's avatar
comfyanonymous committed
48
# When run as a script with a CUDA device requested, export it through
# CUDA_VISIBLE_DEVICES. This sits ahead of the remaining imports below,
# presumably so the variable is in place before CUDA is initialised (confirm).
if __name__ == "__main__" and args.cuda_device is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
    print("Set cuda device to:", args.cuda_device)


import yaml
55

56
import execution
EllangoK's avatar
EllangoK committed
57
import server
space-nuko's avatar
space-nuko committed
58
from server import BinaryEventTypes
EllangoK's avatar
EllangoK committed
59
from nodes import init_custom_nodes
60
import comfy.model_management
61

pythongosssss's avatar
pythongosssss committed
62
def prompt_worker(q, server):
    """Execute queued prompts one after another, forever.

    Args:
        q: Queue object providing ``get()`` and ``task_done(item_id, outputs)``.
        server: Server object handed to the executor and used to notify the
            current client when a prompt has finished.
    """
    executor = execution.PromptExecutor(server)
    while True:
        item, item_id = q.get()

        started_at = time.perf_counter()
        prompt_id = item[1]
        executor.execute(item[2], prompt_id, item[3], item[4])
        q.task_done(item_id, executor.outputs_ui)

        # A "node": None message tells the connected client that nothing is
        # executing any more for this prompt.
        if server.client_id is not None:
            finished = {"node": None, "prompt_id": prompt_id}
            server.send_sync("executing", finished, server.client_id)

        elapsed = time.perf_counter() - started_at
        print("Prompt executed in {:.2f} seconds".format(elapsed))

        # Release memory between prompts.
        gc.collect()
        comfy.model_management.soft_empty_cache()
reaper47's avatar
reaper47 committed
76

77
78
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    """Run the server's start routine and its publish loop concurrently.

    Args:
        server: Object exposing the ``start(...)`` and ``publish_loop()`` coroutines.
        address: Address forwarded to ``server.start``.
        port: Port forwarded to ``server.start``.
        verbose: Verbosity flag forwarded to ``server.start``.
        call_on_start: Optional callback forwarded to ``server.start``.
    """
    serving = server.start(address, port, verbose, call_on_start)
    publishing = server.publish_loop()
    await asyncio.gather(serving, publishing)
comfyanonymous's avatar
comfyanonymous committed
79

reaper47's avatar
reaper47 committed
80

pythongosssss's avatar
pythongosssss committed
81
def hijack_progress(server):
    """Install a global progress-bar hook that forwards updates to the client.

    Args:
        server: Server object whose ``send_sync`` and ``client_id`` are used
            to deliver the messages.
    """
    def hook(value, total, preview_image_bytes):
        # Always report numeric progress.
        progress = {"value": value, "max": total}
        server.send_sync("progress", progress, server.client_id)

        # A preview image is optional; send it only when one was produced.
        if preview_image_bytes is None:
            return
        server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)

    comfy.utils.set_progress_bar_global_hook(hook)
comfyanonymous's avatar
comfyanonymous committed
87

reaper47's avatar
reaper47 committed
88

89
90
91
def cleanup_temp():
    """Delete the ``temp`` directory located next to this file, if present."""
    here = os.path.dirname(os.path.realpath(__file__))
    temp_dir = os.path.join(here, "temp")
    if not os.path.exists(temp_dir):
        return
    # ignore_errors: removal problems are not raised.
    shutil.rmtree(temp_dir, ignore_errors=True)
93

reaper47's avatar
reaper47 committed
94

95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def load_extra_path_config(yaml_path):
    """Register extra model search paths listed in a YAML file.

    The top level of the file is iterated as a mapping of sections. Each
    section maps a folder name to a newline-separated string of paths. A
    section may contain a ``base_path`` entry, which is joined in front of
    every path of that section. Each resulting path is passed to
    ``folder_paths.add_model_folder_path``.

    Empty files, empty sections and folder names without a value are skipped.

    Args:
        yaml_path: Path of the YAML file to read.
    """
    with open(yaml_path, 'r') as stream:
        config = yaml.safe_load(stream)

    # An empty or comment-only document loads as None; nothing to register.
    if not config:
        return

    for conf in config.values():
        if conf is None:
            continue
        base_path = conf.pop("base_path", None)
        for folder_name, raw_paths in conf.items():
            # A key written without a value loads as None; skip it.
            if raw_paths is None:
                continue
            for path in raw_paths.split("\n"):
                if not path:
                    continue
                full_path = path if base_path is None else os.path.join(base_path, path)
                print("Adding extra search path", folder_name, full_path)
                folder_paths.add_model_folder_path(folder_name, full_path)

reaper47's avatar
reaper47 committed
115

comfyanonymous's avatar
comfyanonymous committed
116
if __name__ == "__main__":
    # Needed only for the CI early exit below; ``exit`` from the ``site``
    # module is not meant for use in programs.
    import sys

    # Remove the temp directory, if one exists, before starting.
    cleanup_temp()

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # NOTE: from here on ``server`` is the PromptServer instance and shadows
    # the imported ``server`` module.
    server = server.PromptServer(loop)
    q = execution.PromptQueue(server)

    # Extra model paths: first the optional file next to this script, then any
    # files named on the command line.
    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
    if os.path.isfile(extra_model_paths_config_path):
        load_extra_path_config(extra_model_paths_config_path)

    if args.extra_model_paths_config:
        for config_path in itertools.chain(*args.extra_model_paths_config):
            load_extra_path_config(config_path)

    init_custom_nodes()
    server.add_routes()
    hijack_progress(server)

    # Prompts are executed on a daemon thread, so it never blocks interpreter exit.
    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()

    if args.output_directory:
        output_dir = os.path.abspath(args.output_directory)
        print(f"Setting output directory to: {output_dir}")
        folder_paths.set_output_directory(output_dir)

    if args.quick_test_for_ci:
        sys.exit(0)

    call_on_start = None
    if args.auto_launch:
        def startup_server(address, port):
            """Open the web UI in the default browser."""
            import webbrowser
            # The wildcard bind address is not a usable destination for a
            # browser; point it at the loopback interface instead.
            if address in ("", "0.0.0.0"):
                address = "127.0.0.1"
            webbrowser.open(f"http://{address}:{port}")

        call_on_start = startup_server

    try:
        loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
    except KeyboardInterrupt:
        print("\nStopped server")
    finally:
        # Clean up the temp directory however the server loop ended.
        cleanup_temp()