model_management.py

import sys

# VRAM states: how aggressively model weights are kept off the GPU.
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3

accelerate_enabled = False
vram_state = NORMAL_VRAM

total_vram = 0
total_vram_available_mb = -1

set_vram_to = NORMAL_VRAM

try:
    import torch
    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
    if total_vram <= 4096 and "--normalvram" not in sys.argv:
        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
        set_vram_to = LOW_VRAM
except Exception:
    # torch missing or no usable CUDA device: keep the defaults.
    pass

if "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM



if set_vram_to != NORMAL_VRAM:
    try:
        # Low/No VRAM modes rely on accelerate to offload weights to the CPU.
        import accelerate
        accelerate_enabled = True
        vram_state = set_vram_to
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")

    # Use about half of the VRAM beyond the first 1024MB, with a 256MB floor.
    total_vram_available_mb = (total_vram - 1024) // 2
    total_vram_available_mb = int(max(256, total_vram_available_mb))


print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])


current_loaded_model = None
model_accelerated = False


def unload_model():
    global current_loaded_model
    global model_accelerated
    if current_loaded_model is not None:
        if model_accelerated:
            # Remove the accelerate dispatch hooks before moving the model back.
            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
            model_accelerated = False

        current_loaded_model.model.cpu()
        current_loaded_model.unpatch_model()
        current_loaded_model = None


def load_model_gpu(model):
    global current_loaded_model
    global vram_state
    global model_accelerated

    if model is current_loaded_model:
        return
    unload_model()
    try:
        real_model = model.patch_model()
    except Exception as e:
        model.unpatch_model()
        raise e
    current_loaded_model = model
    if vram_state == CPU:
        pass
    elif vram_state == NORMAL_VRAM:
        model_accelerated = False
        real_model.cuda()
    else:
        # Let accelerate split the model between the GPU and CPU RAM
        # according to the VRAM budget.
        if vram_state == NO_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
        elif vram_state == LOW_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})

        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
        model_accelerated = True
    return current_loaded_model


def get_free_memory():
    dev = torch.cuda.current_device()
    stats = torch.cuda.memory_stats(dev)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
    # Free device memory plus memory held by torch's caching allocator
    # that is not currently in use.
    mem_free_torch = mem_reserved - mem_active
    return mem_free_cuda + mem_free_torch

def maximum_batch_area():
    global vram_state
    if vram_state == NO_VRAM:
        return 0

    memory_free = get_free_memory() / (1024 * 1024)
    area = ((memory_free - 1024) * 0.9) / (0.6)
    return int(max(area, 0))
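

# Illustrative usage sketch, not part of the original module: it assumes a
# CUDA-capable setup and uses a hypothetical stand-in object exposing the
# interface load_model_gpu() expects (.model, .patch_model(), .unpatch_model()).
if __name__ == "__main__":
    import torch.nn as nn

    class _DummyPatcher:
        def __init__(self, model):
            self.model = model

        def patch_model(self):
            # A real patcher would apply weight patches here; the dummy
            # just returns the wrapped module.
            return self.model

        def unpatch_model(self):
            pass

    patcher = _DummyPatcher(nn.Linear(8, 8))
    load_model_gpu(patcher)                             # move or dispatch to the GPU
    print("maximum batch area:", maximum_batch_area())  # rough free-VRAM based budget
    unload_model()                                      # back on the CPU, state cleared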