main.py 2.96 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
2
import os
import sys
3

comfyanonymous's avatar
comfyanonymous committed
4
import threading
5
import asyncio
comfyanonymous's avatar
comfyanonymous committed
6

pythongosssss's avatar
pythongosssss committed
7
8
9
10
if os.name == "nt":
    import logging

    def _drop_missing_triton_warning(record):
        # Returning False drops the record: hide xformers' "A matching Triton is
        # not available" message on Windows, let every other record through.
        return 'A matching Triton is not available' not in record.getMessage()

    logging.getLogger("xformers").addFilter(_drop_missing_triton_warning)

11
import execution
pythongosssss's avatar
pythongosssss committed
12
import server
13

comfyanonymous's avatar
comfyanonymous committed
14
15
16
17
18
19
20
21
def _print_help():
    """Print the command-line options recognised by this launcher to stdout."""
    print("Valid Command line Arguments:")
    print("\t--listen\t\t\tListen on 0.0.0.0 so the UI can be accessed from other computers.")
    print("\t--port 8188\t\t\tSet the listen port.")
    print("\t--dont-print-server\t\tDisable verbose server output.")
    print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
    print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
    print()
    print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
    print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
    print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
    print("\t--novram\t\t\tWhen lowvram isn't enough.")
    print()


if __name__ == "__main__":
    if '--help' in sys.argv:
        _print_help()
        # sys.exit() instead of the site-provided exit() builtin, which is not
        # available under `python -S` or in some embedded/frozen interpreters.
        sys.exit()

    if '--dont-upcast-attention' in sys.argv:
        print("disabling upcasting of attention")
        # NOTE(review): if a module imported earlier in this file reads
        # ATTN_PRECISION at import time, this assignment happens too late -- confirm.
        os.environ['ATTN_PRECISION'] = "fp16"
32

pythongosssss's avatar
pythongosssss committed
33
def prompt_worker(q, server):
    """Execute queued prompts forever; meant to be used as a thread target.

    A single ``execution.PromptExecutor`` is created for ``server`` and reused
    for every item.  Each item taken from ``q`` is executed and then reported
    back with ``q.task_done(item_id, e.outputs)``.

    Args:
        q: queue object providing ``get()`` -> ``(item, item_id)`` and
            ``task_done(item_id, outputs)``.
        server: passed through to ``execution.PromptExecutor``.
    """
    import traceback

    e = execution.PromptExecutor(server)
    while True:
        item, item_id = q.get()
        try:
            # NOTE(review): presumably item[-2] is the prompt and item[-1] its
            # extra data -- confirm against PromptQueue.
            e.execute(item[-2], item[-1])
        except Exception:
            # An uncaught exception would end this (only) worker thread and
            # leave every later prompt unprocessed; log it and keep serving.
            traceback.print_exc()
        finally:
            # Report the item back to the queue whether or not execution succeeded.
            q.task_done(item_id, e.outputs)
comfyanonymous's avatar
comfyanonymous committed
39

comfyanonymous's avatar
comfyanonymous committed
40
41
async def run(server, address='', port=8188, verbose=True):
    """Drive ``server.start`` and ``server.publish_loop`` concurrently.

    Completes when both coroutines finish; the first exception raised by
    either one propagates to the caller (``asyncio.gather`` default).
    """
    serving = server.start(address, port, verbose)
    publishing = server.publish_loop()
    await asyncio.gather(serving, publishing)
comfyanonymous's avatar
comfyanonymous committed
42

pythongosssss's avatar
pythongosssss committed
43
44
45
46
47
48
49
50
51
def hijack_progress(server):
    """Monkey-patch ``tqdm.update`` so progress ticks are forwarded to ``server``.

    The patched ``update`` first runs the original method, then calls
    ``server.send_sync("progress", {"value": bar.n, "max": bar.total},
    server.client_id)``.  ``server.client_id`` is read on every update, not
    captured once.

    Args:
        server: object providing ``send_sync(event, data, client_id)`` and a
            ``client_id`` attribute.
    """
    import functools
    from tqdm.auto import tqdm

    orig_update = tqdm.update

    @functools.wraps(orig_update)
    def wrapped_update(self, *args, **kwargs):
        result = orig_update(self, *args, **kwargs)
        # Read n/total after the original update so the reported value is current.
        server.send_sync("progress", {"value": self.n, "max": self.total}, server.client_id)
        return result

    tqdm.update = wrapped_update
comfyanonymous's avatar
comfyanonymous committed
52
53

def _parse_port(argv, default=8188):
    """Return the integer that follows ``--port`` in *argv*, else *default*.

    If ``--port`` is absent, *default* is returned silently.  If ``--port``
    is the last argument or is followed by a non-integer, a warning is
    printed and *default* is returned.
    """
    if '--port' not in argv:
        return default
    p_index = argv.index('--port')
    try:
        return int(argv[p_index + 1])
    except (IndexError, ValueError):
        print("invalid or missing value for --port, using default port %d" % default)
        return default


if __name__ == "__main__":
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # NOTE: rebinds the module-level name `server` from the imported module to
    # the PromptServer instance; kept for compatibility with the original script.
    server = server.PromptServer(loop)
    q = execution.PromptQueue(server)

    # Forward tqdm progress updates through the server (see hijack_progress).
    hijack_progress(server)

    # Prompt execution runs on its own thread, off the event-loop thread;
    # daemon=True lets the process exit without joining it.
    threading.Thread(target=prompt_worker, daemon=True, args=(q, server)).start()

    address = '0.0.0.0' if '--listen' in sys.argv else '127.0.0.1'
    dont_print = '--dont-print-server' in sys.argv
    port = _parse_port(sys.argv)

    try:
        loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print))
    except KeyboardInterrupt:
        # Ctrl+C exits quietly on Windows only; elsewhere it propagates as before.
        if os.name != "nt":
            raise
comfyanonymous's avatar
comfyanonymous committed
85