model_management.py
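# VRAM states: run on CPU only, offload aggressively (NO_VRAM), offload partially (LOW_VRAM),
# or keep the whole model on the GPU (NORMAL_VRAM).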
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3

accelerate_enabled = False
vram_state = NORMAL_VRAM

total_vram = 0
total_vram_available_mb = -1

import sys

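# Pick the VRAM state from the command-line flags (--lowvram / --novram).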
set_vram_to = NORMAL_VRAM
if "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM

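# Query the total VRAM of the current CUDA device in MB; total_vram stays 0 if torch or CUDA is unavailable.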
try:
    import torch
    total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
except Exception:
    pass

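# Low/no VRAM modes need the accelerate library for CPU offloading; if it cannot be
# imported, vram_state stays at NORMAL_VRAM.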
if set_vram_to != NORMAL_VRAM:
    try:
        import accelerate
        accelerate_enabled = True
        vram_state = set_vram_to
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")

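    # Let the model use roughly half of the VRAM, minus 1GB of headroom, with a 256MB floor.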
    total_vram_available_mb = (total_vram - 1024) // 2
    total_vram_available_mb = int(max(256, total_vram_available_mb))


print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])


current_loaded_model = None


model_accelerated = False


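# Remove any accelerate hooks from the current model, move it back to the CPU and unpatch it.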
def unload_model():
    global current_loaded_model
    global model_accelerated
    if current_loaded_model is not None:
        if model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
            model_accelerated = False

        current_loaded_model.model.cpu()
        current_loaded_model.unpatch_model()
        current_loaded_model = None


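# Patch a model and place it on the GPU (or leave it offloaded) according to the current VRAM state.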
def load_model_gpu(model):
    global current_loaded_model
    global vram_state
    global model_accelerated

    if model is current_loaded_model:
        return
    unload_model()
    try:
        real_model = model.patch_model()
    except Exception as e:
        model.unpatch_model()
        raise e
    current_loaded_model = model
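    # Place the patched model: keep it on the CPU, move it entirely to the GPU,
    # or let accelerate split it between GPU and CPU memory.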
    if vram_state == CPU:
        pass
    elif vram_state == NORMAL_VRAM:
        model_accelerated = False
        real_model.cuda()
    else:
        if vram_state == NO_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
        elif vram_state == LOW_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
        print(device_map, "{}MiB".format(total_vram_available_mb))
        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
        model_accelerated = True
    return current_loaded_model


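# Free memory in bytes: unallocated CUDA memory plus memory torch has reserved but is not actively using.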
def get_free_memory():
    dev = torch.cuda.current_device()
    stats = torch.cuda.memory_stats(dev)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
    mem_free_torch = mem_reserved - mem_active
    return mem_free_cuda + mem_free_torch

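# Heuristic upper bound on the image area that fits in the currently free memory.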
def maximum_batch_area():
    global vram_state
    if vram_state == NO_VRAM:
        return 0

    memory_free = get_free_memory() / (1024 * 1024)
    area = ((memory_free - 1024) * 0.9) / (0.6)
    return int(max(area, 0))