# model_management.py
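"""Track the process-wide VRAM strategy and manage the single currently
loaded model, offloading weights to system RAM via accelerate when the
--lowvram or --novram flag is passed."""
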
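# VRAM strategies, from most to least conservative: run entirely on the CPU,
# keep (almost) nothing resident on the GPU, keep a limited budget on the
# GPU, or load the whole model onto the GPU.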
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3

accelerate_enabled = False
vram_state = NORMAL_VRAM

total_vram_available_mb = -1

import sys

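# Pick the VRAM strategy from the command line; NORMAL_VRAM is the default.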
set_vram_to = NORMAL_VRAM
if "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM

if set_vram_to != NORMAL_VRAM:
    try:
        import accelerate
        accelerate_enabled = True
        vram_state = set_vram_to
    except Exception:
        import traceback
        print(traceback.format_exc())
        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
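    # Estimate the GPU budget for low-VRAM mode: query the device's total
    # VRAM, reserve 1 GiB for inference overhead, keep half the remainder,
    # and never go below 256 MB.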
    try:
        import torch
        # mem_get_info() returns (free, total) in bytes; index 1 is total VRAM.
        total_vram_available_mb = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
    except Exception:
        pass
    total_vram_available_mb = (total_vram_available_mb - 1024) // 2
    total_vram_available_mb = int(max(256, total_vram_available_mb))


print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_state])


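# Module-level state: the model wrapper currently loaded (at most one at a
# time) and whether accelerate hooks are attached to it.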
current_loaded_model = None


model_accelerated = False


def unload_model():
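    """Strip any accelerate hooks, move the current model back to the CPU,
    unpatch it, and forget it."""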
    global current_loaded_model
    global model_accelerated
    if current_loaded_model is not None:
        if model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
            model_accelerated = False

        current_loaded_model.model.cpu()
        current_loaded_model.unpatch_model()
        current_loaded_model = None


def load_model_gpu(model):
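    """Patch `model` and place it on the GPU according to the active VRAM
    strategy, unloading whatever was loaded before. Returns the loaded model."""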
    global current_loaded_model
    global vram_state
    global model_accelerated

    # Already loaded; nothing to do.
    if model is current_loaded_model:
        return current_loaded_model
    unload_model()
    try:
        real_model = model.patch_model()
    except Exception:
        # Roll back a partial patch before propagating the error.
        model.unpatch_model()
        raise
    current_loaded_model = model
    if vram_state == CPU:
        pass
    elif vram_state == NORMAL_VRAM:
        model_accelerated = False
        real_model.cuda()
    else:
        # NO_VRAM / LOW_VRAM: let accelerate split the model between the GPU
        # and system RAM, giving the GPU only a tiny (or estimated) budget.
        if vram_state == NO_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
        elif vram_state == LOW_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})

        accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
        model_accelerated = True
    return current_loaded_model
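
# Minimal usage sketch. The wrapper passed in is assumed to expose a .model
# attribute plus patch_model()/unpatch_model() methods, as the code above
# relies on; `ModelPatcher` here is a hypothetical stand-in for whatever
# provides that interface:
#
#   patcher = ModelPatcher(unet)
#   load_model_gpu(patcher)   # patches, then moves or dispatches to the GPU
#   ...run inference...
#   unload_model()            # removes hooks and returns the model to the CPU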