# model_management.py
# VRAM management modes. The numeric value of `vram_state` is used as an index
# into the mode-name list printed at start-up, so these values must stay 0..3.
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3

# Becomes True once `accelerate` has been imported for a low/no-VRAM mode.
accelerate_enabled = False
# Active mode; start-up detection may lower it from NORMAL_VRAM.
vram_state = NORMAL_VRAM

# Total CUDA memory in MiB; stays 0 when it cannot be queried.
total_vram = 0
# GPU memory budget (MiB) handed to accelerate in LOW_VRAM mode; -1 until computed.
total_vram_available_mb = -1

import sys

# Mode requested by auto-detection and command-line flags. It is only copied
# into `vram_state` below if `accelerate` can be imported.
set_vram_to = NORMAL_VRAM

try:
    import torch
    # Index [1] of mem_get_info is the device's total memory in bytes; store it in MiB.
    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
    if total_vram <= 4096 and "--normalvram" not in sys.argv:
        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
        set_vram_to = LOW_VRAM
except Exception:
    # Best effort: without torch or a usable CUDA device, total_vram keeps its
    # default and no automatic low-VRAM switch happens. Catching Exception
    # (rather than a bare except) lets KeyboardInterrupt/SystemExit propagate.
    pass

# Explicit flags override auto-detection; --novram is checked last, so it wins.
if "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM

if set_vram_to != NORMAL_VRAM:
    try:
        import accelerate
        accelerate_enabled = True
        vram_state = set_vram_to
    except Exception:
        import traceback
        print(traceback.format_exc())
        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")

    # GPU budget for accelerate's device map: half of what remains after
    # reserving 1024 MiB, but never less than 256 MiB.
    total_vram_available_mb = (total_vram - 1024) // 2
    total_vram_available_mb = int(max(256, total_vram_available_mb))

print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])

# Model wrapper currently placed for inference (None when nothing is loaded).
current_loaded_model = None
# Controlnet modules currently moved to the GPU.
current_gpu_controlnets = []

# True when current_loaded_model was dispatched through accelerate
# (set by the low/no-VRAM path when loading a model).
model_accelerated = False


def unload_model():
    """Release the currently loaded model and any GPU controlnets.

    If a model is loaded:

    * remove accelerate's hooks when it was dispatched through accelerate;
    * move its ``.model`` to the CPU;
    * call ``unpatch_model()`` on it;
    * clear ``current_loaded_model``.

    Every controlnet in ``current_gpu_controlnets`` is then moved to the CPU
    and the list is reset to empty.
    """
    global current_loaded_model
    global model_accelerated
    global current_gpu_controlnets

    if current_loaded_model is not None:
        if model_accelerated:
            # The model was split across devices by accelerate.dispatch_model;
            # strip those hooks before moving it.
            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
            model_accelerated = False

        current_loaded_model.model.cpu()
        current_loaded_model.unpatch_model()
        current_loaded_model = None

    if current_gpu_controlnets:
        for n in current_gpu_controlnets:
            n.cpu()
        current_gpu_controlnets = []


def load_model_gpu(model):
    """Make ``model`` the loaded model and place it according to ``vram_state``.

    The previously loaded model is unloaded first. ``model.patch_model()``
    is called, and the object it returns is placed as follows:

    * ``CPU``: left where it is.
    * ``NORMAL_VRAM``: moved to CUDA with ``.cuda()``.
    * ``NO_VRAM`` / ``LOW_VRAM``: split between GPU and CPU by ``accelerate``.
      GPU usage is capped at 256 MiB for ``NO_VRAM`` and at
      ``total_vram_available_mb`` MiB for ``LOW_VRAM``.

    Returns ``model`` (the new ``current_loaded_model``). If ``model`` is
    already the loaded model, nothing is done and it is returned as well.
    Previously the early path returned ``None``.

    Re-raises any exception from ``model.patch_model()`` after calling
    ``model.unpatch_model()``, presumably to roll back a partial patch.
    """
    global current_loaded_model
    global model_accelerated

    if model is current_loaded_model:
        # Already loaded: return it, consistent with the normal path.
        return current_loaded_model
    unload_model()
    try:
        real_model = model.patch_model()
    except Exception:
        model.unpatch_model()
        raise
    current_loaded_model = model

    if vram_state == CPU:
        pass
    elif vram_state == NORMAL_VRAM:
        model_accelerated = False
        real_model.cuda()
    else:
        # NO_VRAM or LOW_VRAM: let accelerate decide which submodules fit on GPU 0.
        if vram_state == NO_VRAM:
            gpu_budget = "256MiB"
        else:
            gpu_budget = "{}MiB".format(total_vram_available_mb)
        device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: gpu_budget, "cpu": "16GiB"})

        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
        model_accelerated = True
    return current_loaded_model


def load_controlnet_gpu(models):
    """Make exactly ``models`` the set of controlnets held on the GPU.

    Controlnets in ``current_gpu_controlnets`` that are not in ``models``
    are moved back to the CPU. Every entry of ``models`` is then moved with
    ``.cuda()``, and the results become the new ``current_gpu_controlnets``.

    ``models`` may be any iterable. It is materialized once, so a generator
    is not exhausted by the membership test before the move-to-GPU step.
    """
    global current_gpu_controlnets
    models = list(models)
    for m in current_gpu_controlnets:
        if m not in models:
            m.cpu()

    current_gpu_controlnets = [m.cuda() for m in models]


def get_free_memory():
    """Return free memory on the current CUDA device, in bytes.

    The result is the driver-reported free memory plus the bytes torch has
    reserved but that are not counted as active.
    """
    device = torch.cuda.current_device()
    allocator_stats = torch.cuda.memory_stats(device)
    reserved_now = allocator_stats['reserved_bytes.all.current']
    active_now = allocator_stats['active_bytes.all.current']
    driver_free = torch.cuda.mem_get_info(device)[0]
    # Reserved-but-inactive bytes are held by torch yet reusable by it.
    reusable_in_torch = reserved_now - active_now
    return driver_free + reusable_in_torch

def maximum_batch_area():
    """Return an integer upper bound on batch area derived from free GPU memory.

    Returns 0 in NO_VRAM mode or when less than 1024 MiB is free.
    """
    if vram_state == NO_VRAM:
        return 0

    free_mb = get_free_memory() / (1024 * 1024)
    # Reserve 1024 MiB, then scale by 0.9 / 0.6. These are tuning factors
    # whose derivation is not shown here.
    headroom = (free_mb - 1024) * 0.9
    estimated_area = headroom / (0.6)
    return int(max(estimated_area, 0))