model_management.py
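
# VRAM states, from running everything on the CPU to keeping whole models on the GPU.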
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3

accelerate_enabled = False
vram_state = NORMAL_VRAM

total_vram = 0
total_vram_available_mb = -1

import sys

set_vram_to = NORMAL_VRAM

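# Detect total VRAM with torch; if the GPU has 4GB or less, default to low vram
# mode unless --normalvram is passed. Any failure here leaves the defaults alone.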
try:
    import torch
    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
    if total_vram <= 4096 and "--normalvram" not in sys.argv:
        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
        set_vram_to = LOW_VRAM
except:
    pass

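# Command line flags override the auto-detected default.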
if "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM



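# Low and no vram modes rely on accelerate to split the model between GPU and CPU.
# The per-model VRAM budget is roughly half of what remains after reserving 1GB,
# clamped to at least 256MB.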
if set_vram_to != NORMAL_VRAM:
    try:
        import accelerate
        accelerate_enabled = True
        vram_state = set_vram_to
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")

    total_vram_available_mb = (total_vram - 1024) // 2
    total_vram_available_mb = int(max(256, total_vram_available_mb))


print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])


current_loaded_model = None
current_gpu_controlnets = []

model_accelerated = False


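# Moves the currently loaded model (and any controlnets on the GPU) back to the CPU.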
def unload_model():
    global current_loaded_model
    global model_accelerated
    global current_gpu_controlnets
    if current_loaded_model is not None:
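        # strip the accelerate dispatch hooks before moving the weights back to the CPU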
        if model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
            model_accelerated = False

        current_loaded_model.model.cpu()
        current_loaded_model.unpatch_model()
        current_loaded_model = None
    if len(current_gpu_controlnets) > 0:
        for n in current_gpu_controlnets:
            n.cpu()
        current_gpu_controlnets = []


def load_model_gpu(model):
    global current_loaded_model
    global vram_state
    global model_accelerated

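    # nothing to do if this model is already the one loaded on the GPU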
    if model is current_loaded_model:
        return
    unload_model()
    try:
        real_model = model.patch_model()
    except Exception as e:
        model.unpatch_model()
        raise e
    current_loaded_model = model
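    # Place the patched model according to the VRAM policy: keep it on the CPU,
    # move it fully to the GPU, or let accelerate split it across GPU and CPU.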
    if vram_state == CPU:
        pass
    elif vram_state == NORMAL_VRAM:
        model_accelerated = False
        real_model.cuda()
    else:
        if vram_state == NO_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
        elif vram_state == LOW_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})

        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
        model_accelerated = True
    return current_loaded_model

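# Keeps the given controlnet models resident on the GPU (normal vram mode only).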
def load_controlnet_gpu(models):
    global current_gpu_controlnets
    global vram_state

    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
        # don't load controlnets this way in low vram mode because they will be loaded right before running and unloaded right after
        return

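    # move controlnets that are no longer wanted back to the CPU, then load the new set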
    for m in current_gpu_controlnets:
        if m not in models:
            m.cpu()

    current_gpu_controlnets = []
    for m in models:
        current_gpu_controlnets.append(m.cuda())


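# In low/no vram modes, small models are shuttled onto the GPU right before use
# and back off right after; in normal vram mode these helpers leave the model as-is.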
def load_if_low_vram(model):
    global vram_state
    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
        return model.cuda()
    return model

def unload_if_low_vram(model):
    global vram_state
    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
        return model.cpu()
    return model


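# Free VRAM = memory CUDA reports as free plus memory torch has reserved but is not actively using.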
def get_free_memory():
    dev = torch.cuda.current_device()
    stats = torch.cuda.memory_stats(dev)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
    mem_free_torch = mem_reserved - mem_active
    return mem_free_cuda + mem_free_torch

def maximum_batch_area():
    global vram_state
    if vram_state == NO_VRAM:
        return 0

    memory_free = get_free_memory() / (1024 * 1024)
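    # heuristic from the constants below: leave 1GB of headroom, keep a 10% safety
    # margin, and assume roughly 0.6MB of VRAM per unit of area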
    area = ((memory_free - 1024) * 0.9) / (0.6)
    return int(max(area, 0))