main.py 9.29 KB
Newer Older
1
2
3
import comfy.options
comfy.options.enable_args_parsing()

4
5
6
import os
import importlib.util
import folder_paths
7
import time
8
9
from comfy.cli_args import args

10
11
12

def execute_prestartup_script():
    """Run every custom node's prestartup_script.py before the main imports.

    Scans each configured "custom_nodes" folder for subdirectories that
    contain a ``prestartup_script.py``, executes each one, and prints a
    per-script timing report. Does nothing when --disable-all-custom-nodes
    is set.
    """
    def execute_script(script_path):
        # Import the script as a throwaway module purely for its side effects;
        # returns True on success, False if it raised.
        module_name = os.path.splitext(script_path)[0]
        try:
            spec = importlib.util.spec_from_file_location(module_name, script_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return True
        except Exception as e:
            print(f"Failed to execute startup-script: {script_path} / {e}")
        return False

    if args.disable_all_custom_nodes:
        return

    node_paths = folder_paths.get_folder_paths("custom_nodes")
    # Accumulate (duration, module_path, success) across ALL custom-node
    # folders. Fix: this list used to be re-created inside the loop below,
    # which silently discarded timings from every folder but the last and
    # raised NameError at the report step when node_paths was empty.
    node_prestartup_times = []
    for custom_node_path in node_paths:
        possible_modules = os.listdir(custom_node_path)

        for possible_module in possible_modules:
            module_path = os.path.join(custom_node_path, possible_module)
            # Skip plain files, disabled nodes, and Python bytecode caches.
            if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
                continue

            script_path = os.path.join(module_path, "prestartup_script.py")
            if os.path.exists(script_path):
                time_before = time.perf_counter()
                success = execute_script(script_path)
                node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
    if len(node_prestartup_times) > 0:
        print("\nPrestartup times for custom nodes:")
        for n in sorted(node_prestartup_times):
            if n[2]:
                import_message = ""
            else:
                import_message = " (PRESTARTUP FAILED)"
            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
        print()

execute_prestartup_script()


# Main code
EllangoK's avatar
EllangoK committed
55
import asyncio
56
import itertools
57
import shutil
comfyanonymous's avatar
comfyanonymous committed
58
import threading
59
import gc
60

61
import logging
comfyanonymous's avatar
comfyanonymous committed
62

pythongosssss's avatar
pythongosssss committed
63
64
65
if os.name == "nt":
    # Silence xformers' repeated Triton-unavailable warning on Windows, where
    # Triton is generally not installed.
    _xformers_logger = logging.getLogger("xformers")
    _xformers_logger.addFilter(
        lambda record: 'A matching Triton is not available' not in record.getMessage()
    )

comfyanonymous's avatar
comfyanonymous committed
66
if __name__ == "__main__":
    # This early __main__ block runs before torch/CUDA are imported on
    # purpose: the environment variables below only take effect if they are
    # set before CUDA initialization.
    if args.cuda_device is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
        logging.info("Set cuda device to: {}".format(args.cuda_device))

    if args.deterministic:
        # cuBLAS requires a workspace config for deterministic kernels; only
        # set the default if the user hasn't already chosen a value.
        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

    # Imported for its side effects. NOTE(review): presumably configures the
    # CUDA allocator before torch is imported — confirm in cuda_malloc.py.
    import cuda_malloc
EllangoK's avatar
EllangoK committed
76

77
import comfy.utils
EllangoK's avatar
EllangoK committed
78
import yaml
79

80
import execution
EllangoK's avatar
EllangoK committed
81
import server
space-nuko's avatar
space-nuko committed
82
from server import BinaryEventTypes
83
import nodes
84
import comfy.model_management
85

86
87
88
89
90
91
92
93
94
def cuda_malloc_warning():
    """Warn when the active torch device appears on the cuda-malloc blacklist."""
    device = comfy.model_management.get_torch_device()
    name = comfy.model_management.get_torch_device_name(device)
    # Only relevant when the async CUDA allocator is actually in use.
    if "cudaMallocAsync" in name:
        if any(entry in name for entry in cuda_malloc.blacklist):
            logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
96

pythongosssss's avatar
pythongosssss committed
97
def prompt_worker(q, server):
    """Worker loop: pull prompts from the queue, execute them, and run
    periodic garbage collection / model cleanup after work completes.

    q: the PromptQueue providing work items and control flags.
    server: the PromptServer used for client notifications and bookkeeping.
    Runs forever; intended to be started on a daemon thread.
    """
    e = execution.PromptExecutor(server)
    last_gc_collect = 0
    need_gc = False
    gc_collect_interval = 10.0  # seconds between post-execution gc passes

    while True:
        timeout = 1000.0
        if need_gc:
            # Shorten the queue wait so we wake up in time for the pending gc
            # pass. NOTE(review): this reads current_time from a previous
            # iteration — need_gc only ever becomes True in branches below
            # that also lead to current_time being assigned, so it is bound
            # here, but the value may be slightly stale.
            timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)

        # queue_item may be None (presumably when the timeout expires with no
        # work — verify against PromptQueue.get).
        queue_item = q.get(timeout=timeout)
        if queue_item is not None:
            item, item_id = queue_item
            execution_start_time = time.perf_counter()
            prompt_id = item[1]
            server.last_prompt_id = prompt_id

            e.execute(item[2], prompt_id, item[3], item[4])
            need_gc = True
            q.task_done(item_id,
                        e.outputs_ui,
                        status=execution.PromptQueue.ExecutionStatus(
                            status_str='success' if e.success else 'error',
                            completed=e.success,
                            messages=e.status_messages))
            if server.client_id is not None:
                # Tell the client that no node is executing for this prompt
                # anymore.
                server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)

            current_time = time.perf_counter()
            execution_time = current_time - execution_start_time
            logging.info("Prompt executed in {:.2f} seconds".format(execution_time))

        flags = q.get_flags()
        free_memory = flags.get("free_memory", False)

        # free_memory implies unload_models (used as the default here).
        if flags.get("unload_models", free_memory):
            comfy.model_management.unload_all_models()
            need_gc = True
            last_gc_collect = 0  # force the next gc pass to run immediately

        if free_memory:
            e.reset()
            need_gc = True
            last_gc_collect = 0

        if need_gc:
            current_time = time.perf_counter()
            if (current_time - last_gc_collect) > gc_collect_interval:
                comfy.model_management.cleanup_models()
                gc.collect()
                comfy.model_management.soft_empty_cache()
                last_gc_collect = current_time
                need_gc = False
reaper47's avatar
reaper47 committed
151

152
153
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    """Run the HTTP server and its publish loop concurrently until cancelled."""
    web_task = server.start(address, port, verbose, call_on_start)
    await asyncio.gather(web_task, server.publish_loop())
comfyanonymous's avatar
comfyanonymous committed
154

reaper47's avatar
reaper47 committed
155

pythongosssss's avatar
pythongosssss committed
156
def hijack_progress(server):
    """Install a global progress hook that relays execution progress (and
    optional preview images) to the connected client via the server."""
    def report(value, total, preview_image):
        # Raise immediately if the user interrupted processing.
        comfy.model_management.throw_exception_if_processing_interrupted()
        payload = {
            "value": value,
            "max": total,
            "prompt_id": server.last_prompt_id,
            "node": server.last_node_id,
        }
        server.send_sync("progress", payload, server.client_id)
        if preview_image is not None:
            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
    comfy.utils.set_progress_bar_global_hook(report)
comfyanonymous's avatar
comfyanonymous committed
165

reaper47's avatar
reaper47 committed
166

167
def cleanup_temp():
    """Delete the temp directory and all of its contents, if it exists."""
    path = folder_paths.get_temp_directory()
    if not os.path.exists(path):
        return
    shutil.rmtree(path, ignore_errors=True)
171

reaper47's avatar
reaper47 committed
172

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def load_extra_path_config(yaml_path):
    """Register extra model search paths from a YAML configuration file.

    The file maps section names to dicts of {folder_kind: newline-separated
    paths}; an optional "base_path" entry in a section is prepended to every
    path in that section. Each resulting path is registered with
    folder_paths.add_model_folder_path.
    """
    with open(yaml_path, 'r') as stream:
        config = yaml.safe_load(stream)
    # Fix: yaml.safe_load returns None for an empty file, which previously
    # crashed the "for c in config" iteration with a TypeError.
    if config is None:
        return
    for c in config:
        conf = config[c]
        if conf is None:
            continue
        base_path = None
        if "base_path" in conf:
            base_path = conf.pop("base_path")
        for x in conf:
            for y in conf[x].split("\n"):
                if len(y) == 0:
                    continue
                full_path = y
                if base_path is not None:
                    full_path = os.path.join(base_path, full_path)
                logging.info("Adding extra search path {} {}".format(x, full_path))
                folder_paths.add_model_folder_path(x, full_path)

reaper47's avatar
reaper47 committed
193

comfyanonymous's avatar
comfyanonymous committed
194
if __name__ == "__main__":
    # --- Directory configuration (before anything writes to disk) ---
    if args.temp_directory:
        temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
        logging.info(f"Setting temp directory to: {temp_dir}")
        folder_paths.set_temp_directory(temp_dir)
    # Clear out leftovers from a previous run.
    cleanup_temp()

    # Best-effort self-update of the Windows standalone updater; failures are
    # deliberately ignored so a broken updater never blocks startup.
    if args.windows_standalone_build:
        try:
            import new_updater
            new_updater.update_windows_updater()
        except:
            pass

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # NOTE: rebinds the module name `server` from the imported module to the
    # PromptServer instance — later references use the instance.
    server = server.PromptServer(loop)
    q = execution.PromptQueue(server)

    # Load extra model search paths: the optional file next to this script
    # first, then any paths given on the command line.
    extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
    if os.path.isfile(extra_model_paths_config_path):
        load_extra_path_config(extra_model_paths_config_path)

    if args.extra_model_paths_config:
        for config_path in itertools.chain(*args.extra_model_paths_config):
            load_extra_path_config(config_path)

    nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)

    cuda_malloc_warning()

    server.add_routes()
    hijack_progress(server)

    # Prompt execution happens on a daemon worker thread; the main thread
    # runs the asyncio server loop below.
    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()

    if args.output_directory:
        output_dir = os.path.abspath(args.output_directory)
        logging.info(f"Setting output directory to: {output_dir}")
        folder_paths.set_output_directory(output_dir)

    #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
    folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
    folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
    folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))

    if args.input_directory:
        input_dir = os.path.abspath(args.input_directory)
        logging.info(f"Setting input directory to: {input_dir}")
        folder_paths.set_input_directory(input_dir)

    # CI smoke test: exit once initialization has succeeded.
    if args.quick_test_for_ci:
        exit(0)

    call_on_start = None
    if args.auto_launch:
        # Invoked by the server once it is listening; opens the UI in a browser.
        def startup_server(scheme, address, port):
            import webbrowser
            # 0.0.0.0 is a bind address, not a connectable one on Windows.
            if os.name == 'nt' and address == '0.0.0.0':
                address = '127.0.0.1'
            webbrowser.open(f"{scheme}://{address}:{port}")
        call_on_start = startup_server

    try:
        # Blocks until the server stops; Ctrl-C is the normal shutdown path.
        loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
    except KeyboardInterrupt:
        logging.info("\nStopped server")

    cleanup_temp()