main.py 3.37 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
2
import os
import sys
3

comfyanonymous's avatar
comfyanonymous committed
4
import threading
5
import asyncio
comfyanonymous's avatar
comfyanonymous committed
6

pythongosssss's avatar
pythongosssss committed
7
8
9
10
if os.name == "nt":
    import logging

    # On Windows, xformers may log 'A matching Triton is not available'.
    # Drop only that record; every other xformers log message passes through.
    def _hide_missing_triton_notice(record):
        message = record.getMessage()
        return 'A matching Triton is not available' not in message

    logging.getLogger("xformers").addFilter(_hide_missing_triton_notice)

11
import execution
pythongosssss's avatar
pythongosssss committed
12
import server
13

comfyanonymous's avatar
comfyanonymous committed
14
15
16
17
18
19
20
21
if __name__ == "__main__":
    # Hand-rolled flag handling: flags are detected by plain membership in sys.argv.
    if '--help' in sys.argv:
        print("Valid Command line Arguments:")
        print("\t--listen\t\t\tListen on 0.0.0.0 so the UI can be accessed from other computers.")
        print("\t--port 8188\t\t\tSet the listen port.")
        print("\t--dont-print-server\t\tDon't print server output.")
        print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
        print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
        print()
        print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
        print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
        print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
        print("\t--novram\t\t\tWhen lowvram isn't enough.")
        print()
        # "--cpu" is short, so it needs four tabs to reach the description column.
        print("\t--cpu\t\t\t\tTo use the CPU for everything (slow).")
        # sys.exit() instead of the site-provided exit(): exit() is only a REPL
        # convenience injected by the `site` module and is missing under
        # `python -S` and in some embedded interpreters.
        sys.exit()

    if '--dont-upcast-attention' in sys.argv:
        print("disabling upcasting of attention")
        # NOTE(review): presumably read by the attention code elsewhere to pick
        # its precision; the consumer of ATTN_PRECISION is not visible here.
        os.environ['ATTN_PRECISION'] = "fp16"
33

pythongosssss's avatar
pythongosssss committed
34
def prompt_worker(q, server):
    """Process queued prompts in an endless loop.

    One ``execution.PromptExecutor`` is built for *server* and reused for every
    job. Each ``q.get()`` yields ``(item, item_id)``. The last two elements of
    ``item`` are handed to ``execute`` and, afterwards, the job is reported
    back via ``q.task_done(item_id, <executor outputs>)``. Never returns.

    NOTE(review): nothing here catches exceptions, so an exception escaping
    ``execute`` ends this loop (and the thread running it). Confirm that
    PromptExecutor.execute handles its own errors.
    """
    executor = execution.PromptExecutor(server)
    while True:
        job, job_id = q.get()
        prompt_args = (job[-2], job[-1])
        executor.execute(*prompt_args)
        finished_outputs = executor.outputs
        q.task_done(job_id, finished_outputs)
comfyanonymous's avatar
comfyanonymous committed
40

41
42
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    """Run the server start-up coroutine and its publish loop concurrently.

    Awaits ``server.start(address, port, verbose, call_on_start)`` and
    ``server.publish_loop()`` together via ``asyncio.gather``. The call
    finishes only when both have finished, and returns None.
    """
    startup = server.start(address, port, verbose, call_on_start)
    publisher = server.publish_loop()
    await asyncio.gather(startup, publisher)
comfyanonymous's avatar
comfyanonymous committed
43

pythongosssss's avatar
pythongosssss committed
44
45
46
47
48
49
50
51
52
def hijack_progress(server):
    """Patch tqdm so every progress-bar update is also pushed to the UI.

    Replaces ``tqdm.auto.tqdm.update`` process-wide with a wrapper. The wrapper
    calls the original method, then sends a ``"progress"`` event carrying the
    bar's current count (``n``) and ``total``. The original return value is
    passed through unchanged.

    :param server: object called as
        ``server.send_sync("progress", data, server.client_id)``.
        ``client_id`` is read on every update, so each event goes to whichever
        client id is current at that moment.
    """
    import functools
    from tqdm.auto import tqdm

    orig_update = tqdm.update

    # functools.wraps keeps __name__/__doc__ of tqdm.update on the patched method.
    @functools.wraps(orig_update)
    def wrapped_update(self, *args, **kwargs):
        result = orig_update(self, *args, **kwargs)
        # Read n/total after the original update so the event reflects the new position.
        server.send_sync("progress", {"value": self.n, "max": self.total}, server.client_id)
        return result

    tqdm.update = wrapped_update
comfyanonymous's avatar
comfyanonymous committed
53
54

def _parse_port(argv, default=8188):
    """Return the integer that follows ``--port`` in *argv*, else *default*.

    A missing, trailing, or non-integer ``--port`` value falls back to
    *default* and prints a warning instead of being silently ignored.
    """
    if '--port' not in argv:
        return default
    try:
        return int(argv[argv.index('--port') + 1])
    except (IndexError, ValueError):
        print("invalid or missing value for --port, using default port {}".format(default))
        return default


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # NOTE: this rebinds the module name `server` to the PromptServer instance;
    # nothing below uses the module itself.
    server = server.PromptServer(loop)
    q = execution.PromptQueue(server)

    hijack_progress(server)

    # Prompts are executed on a daemon thread; the asyncio loop below runs on the main thread.
    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()

    address = '0.0.0.0' if '--listen' in sys.argv else '127.0.0.1'
    dont_print = '--dont-print-server' in sys.argv
    port = _parse_port(sys.argv)

    call_on_start = None
    if "--windows-standalone-build" in sys.argv:
        def startup_server(address, port):
            import webbrowser
            # 0.0.0.0 is a bind-all address, not a connectable one (Windows
            # refuses it), so send the browser to loopback when --listen is set.
            if address == '0.0.0.0':
                address = '127.0.0.1'
            webbrowser.open("http://{}:{}".format(address, port))
        call_on_start = startup_server

    try:
        loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start))
    except KeyboardInterrupt:
        # Ctrl+C exits quietly on Windows only; elsewhere it propagates as before.
        if os.name != "nt":
            raise
comfyanonymous's avatar
comfyanonymous committed
93