# main.py — ComfyUI entry point.
import comfy.options
comfy.options.enable_args_parsing()

4
5
6
import os
import importlib.util
import folder_paths
7
import time
8
9
from comfy.cli_args import args

10
11
12

def execute_prestartup_script():
    """Run every custom node's ``prestartup_script.py`` before the heavy imports.

    Each subdirectory of a "custom_nodes" folder may ship a
    ``prestartup_script.py`` that is executed here so it can patch the
    environment early.  Per-script timing and success are printed afterwards.
    Honors ``--disable-all-custom-nodes``.
    """
    def execute_script(script_path):
        # Import and execute one script file; returns True on success.
        module_name = os.path.splitext(script_path)[0]
        try:
            spec = importlib.util.spec_from_file_location(module_name, script_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return True
        except Exception as e:
            print(f"Failed to execute startup-script: {script_path} / {e}")
        return False

    if args.disable_all_custom_nodes:
        return

    node_paths = folder_paths.get_folder_paths("custom_nodes")
    # Fix: accumulate timings across ALL custom-node paths.  Previously this
    # list was re-created inside the loop, dropping timings from earlier paths
    # and raising NameError below when node_paths was empty.
    node_prestartup_times = []
    for custom_node_path in node_paths:
        possible_modules = os.listdir(custom_node_path)

        for possible_module in possible_modules:
            module_path = os.path.join(custom_node_path, possible_module)
            # Skip plain files, disabled nodes and the __pycache__ folder.
            if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
                continue

            script_path = os.path.join(module_path, "prestartup_script.py")
            if os.path.exists(script_path):
                time_before = time.perf_counter()
                success = execute_script(script_path)
                node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
    if len(node_prestartup_times) > 0:
        print("\nPrestartup times for custom nodes:")
        for n in sorted(node_prestartup_times):
            if n[2]:
                import_message = ""
            else:
                import_message = " (PRESTARTUP FAILED)"
            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
        print()
50
51
52
53
54

# Run custom-node prestartup scripts before the heavy imports below, so they
# can patch the environment (paths, env vars) first.
execute_prestartup_script()


# Main code
EllangoK's avatar
EllangoK committed
55
import asyncio
56
import itertools
57
import shutil
comfyanonymous's avatar
comfyanonymous committed
58
import threading
59
import gc
60

61
import logging
comfyanonymous's avatar
comfyanonymous committed
62

pythongosssss's avatar
pythongosssss committed
63
64
65
# On Windows, silence the noisy xformers warning about a missing Triton build
# (Triton is not available there anyway).
if os.name == "nt":
    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

comfyanonymous's avatar
comfyanonymous committed
66
if __name__ == "__main__":
    # Pin CUDA to the requested device before torch/cuda_malloc are imported.
    if args.cuda_device is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
        logging.info("Set cuda device to: {}".format(args.cuda_device))

    # Deterministic torch algorithms require this cuBLAS workspace setting;
    # respect a value the user already exported.
    if args.deterministic:
        os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ":4096:8")

    import cuda_malloc
EllangoK's avatar
EllangoK committed
76

77
78
79
80
81
82
# Best-effort torch DLL patch for the standalone Windows build; a missing or
# failing fix_torch must never block startup.
if args.windows_standalone_build:
    try:
        import fix_torch
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        pass

83
import comfy.utils
EllangoK's avatar
EllangoK committed
84
import yaml
85

86
import execution
EllangoK's avatar
EllangoK committed
87
import server
space-nuko's avatar
space-nuko committed
88
from server import BinaryEventTypes
89
import nodes
90
import comfy.model_management
91

92
93
94
95
96
97
98
99
100
def cuda_malloc_warning():
    """Warn when the active GPU looks incompatible with cudaMallocAsync."""
    device = comfy.model_management.get_torch_device()
    device_name = comfy.model_management.get_torch_device_name(device)
    if "cudaMallocAsync" in device_name:
        # Any blacklist entry appearing in the device name triggers the warning.
        if any(entry in device_name for entry in cuda_malloc.blacklist):
            logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
102

pythongosssss's avatar
pythongosssss committed
103
def prompt_worker(q, server):
    """Background worker: execute queued prompts and schedule periodic GC.

    Pulls items from the prompt queue, runs them through a PromptExecutor,
    reports results back to the queue/server, and garbage-collects at most
    once every ``gc_interval`` seconds after activity.
    """
    executor = execution.PromptExecutor(server)
    gc_interval = 10.0
    last_gc = 0
    pending_gc = False

    while True:
        timeout = 1000.0
        if pending_gc:
            # Wake up in time for the pending garbage collection.
            # (`now` was set on the previous iteration when pending_gc became True.)
            timeout = max(gc_interval - (now - last_gc), 0.0)

        entry = q.get(timeout=timeout)
        if entry is not None:
            item, item_id = entry
            started = time.perf_counter()
            prompt_id = item[1]
            server.last_prompt_id = prompt_id

            executor.execute(item[2], prompt_id, item[3], item[4])
            pending_gc = True
            q.task_done(item_id,
                        executor.outputs_ui,
                        status=execution.PromptQueue.ExecutionStatus(
                            status_str='success' if executor.success else 'error',
                            completed=executor.success,
                            messages=executor.status_messages))
            if server.client_id is not None:
                # Clear the "currently executing node" indicator on the client.
                server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)

            now = time.perf_counter()
            logging.info("Prompt executed in {:.2f} seconds".format(now - started))

        flags = q.get_flags()
        free_memory = flags.get("free_memory", False)

        # free_memory implies unload_models.
        if flags.get("unload_models", free_memory):
            comfy.model_management.unload_all_models()
            pending_gc = True
            last_gc = 0

        if free_memory:
            executor.reset()
            pending_gc = True
            last_gc = 0

        if pending_gc:
            now = time.perf_counter()
            if (now - last_gc) > gc_interval:
                comfy.model_management.cleanup_models()
                gc.collect()
                comfy.model_management.soft_empty_cache()
                last_gc = now
                pending_gc = False
reaper47's avatar
reaper47 committed
157

158
159
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    """Run the web server and its publish loop concurrently until both finish."""
    serve = server.start(address, port, verbose, call_on_start)
    publish = server.publish_loop()
    await asyncio.gather(serve, publish)
comfyanonymous's avatar
comfyanonymous committed
160

reaper47's avatar
reaper47 committed
161

pythongosssss's avatar
pythongosssss committed
162
def hijack_progress(server):
    """Install a global progress-bar hook that relays progress to the client."""
    def hook(value, total, preview_image):
        # Abort promptly if the user requested an interrupt.
        comfy.model_management.throw_exception_if_processing_interrupted()
        server.send_sync(
            "progress",
            {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id},
            server.client_id,
        )
        if preview_image is not None:
            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
    comfy.utils.set_progress_bar_global_hook(hook)
comfyanonymous's avatar
comfyanonymous committed
171

reaper47's avatar
reaper47 committed
172

173
def cleanup_temp():
    """Remove ComfyUI's temp directory and all its contents, best-effort."""
    path = folder_paths.get_temp_directory()
    if os.path.exists(path):
        shutil.rmtree(path, ignore_errors=True)
177

reaper47's avatar
reaper47 committed
178

179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def load_extra_path_config(yaml_path):
    """Register extra model search paths from a YAML config file.

    Each top-level section maps folder names to newline-separated paths.  An
    optional "base_path" entry inside a section is prepended to every path in
    that section.  Empty sections and blank lines are skipped.
    """
    with open(yaml_path, 'r') as stream:
        config = yaml.safe_load(stream)
    for section_name, section in config.items():
        if section is None:
            continue
        base_path = section.pop("base_path") if "base_path" in section else None
        for folder_name, raw_paths in section.items():
            for entry in raw_paths.split("\n"):
                if len(entry) == 0:
                    continue
                full_path = entry if base_path is None else os.path.join(base_path, entry)
                logging.info("Adding extra search path {} {}".format(folder_name, full_path))
                folder_paths.add_model_folder_path(folder_name, full_path)

reaper47's avatar
reaper47 committed
199

comfyanonymous's avatar
comfyanonymous committed
200
if __name__ == "__main__":
    # Optional temp-directory override; a "temp" subfolder is always appended.
    if args.temp_directory:
        temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
        logging.info(f"Setting temp directory to: {temp_dir}")
        folder_paths.set_temp_directory(temp_dir)
    cleanup_temp()

    # Best-effort self-update of the standalone Windows updater.
    if args.windows_standalone_build:
        try:
            import new_updater
            new_updater.update_windows_updater()
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
            pass

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server = server.PromptServer(loop)
    q = execution.PromptQueue(server)

    # Extra model search paths: first the default YAML next to this file,
    # then any config files passed on the command line.
    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
    if os.path.isfile(extra_model_paths_config_path):
        load_extra_path_config(extra_model_paths_config_path)

    if args.extra_model_paths_config:
        for config_path in itertools.chain(*args.extra_model_paths_config):
            load_extra_path_config(config_path)

    nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)

    cuda_malloc_warning()

    server.add_routes()
    hijack_progress(server)

    # Prompt execution runs on a background daemon thread.
    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()

    if args.output_directory:
        output_dir = os.path.abspath(args.output_directory)
        logging.info(f"Setting output directory to: {output_dir}")
        folder_paths.set_output_directory(output_dir)

    #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
    folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
    folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
    folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))

    if args.input_directory:
        input_dir = os.path.abspath(args.input_directory)
        logging.info(f"Setting input directory to: {input_dir}")
        folder_paths.set_input_directory(input_dir)

    if args.quick_test_for_ci:
        exit(0)

    call_on_start = None
    if args.auto_launch:
        def startup_server(scheme, address, port):
            import webbrowser
            # 0.0.0.0 is not a browsable address on Windows; open loopback instead.
            if os.name == 'nt' and address == '0.0.0.0':
                address = '127.0.0.1'
            webbrowser.open(f"{scheme}://{address}:{port}")
        call_on_start = startup_server

    try:
        loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
    except KeyboardInterrupt:
        logging.info("\nStopped server")

    cleanup_temp()