model_management.py

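
# Device / VRAM management for model loading. Detects the available backend
# (CUDA, Intel XPU via IPEX, Apple MPS or plain CPU), picks a VRAM strategy and
# exposes helpers for moving models and controlnets on and off the device.
# The constants below are the possible vram_state values.
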
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5

accelerate_enabled = False
xpu_available = False

vram_state = NORMAL_VRAM

total_vram = 0
total_vram_available_mb = -1

import sys
import psutil

forced_cpu = "--cpu" in sys.argv

set_vram_to = NORMAL_VRAM
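
# Auto-detect total VRAM and system RAM and pick a default strategy: 4GB of
# VRAM or less enables lowvram mode, noticeably more VRAM than RAM (and more
# than 14GB of VRAM) enables highvram mode. --normalvram skips this.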

try:
    import torch
    import intel_extension_for_pytorch as ipex
    if torch.xpu.is_available():
        xpu_available = True
        total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
    else:
        total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
    total_ram = psutil.virtual_memory().total / (1024 * 1024)
    forced_normal_vram = "--normalvram" in sys.argv
    if not forced_normal_vram and not forced_cpu:
        if total_vram <= 4096:
            print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
            set_vram_to = LOW_VRAM
        elif total_vram > total_ram * 1.1 and total_vram > 14336:
            print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
            vram_state = HIGH_VRAM
except:
    pass

try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
    OOM_EXCEPTION = Exception
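
# Attention backends: xformers is used when it is installed and not disabled;
# --use-pytorch-cross-attention switches to PyTorch scaled dot product attention.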

if "--disable-xformers" in sys.argv:
    XFORMERS_IS_AVAILBLE = False
else:
    try:
        import xformers
        import xformers.ops
        XFORMERS_IS_AVAILBLE = True
    except:
        XFORMERS_IS_AVAILBLE = False

ENABLE_PYTORCH_ATTENTION = False
if "--use-pytorch-cross-attention" in sys.argv:
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    ENABLE_PYTORCH_ATTENTION = True
    XFORMERS_IS_AVAILBLE = False
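
# Explicit command line overrides for the VRAM strategy.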

if "--lowvram" in sys.argv:
    set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
    set_vram_to = NO_VRAM
if "--highvram" in sys.argv:
    vram_state = HIGH_VRAM
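
# LOW/NO VRAM modes offload part of the model to system RAM via accelerate.
# total_vram_available_mb is the budget for weights kept on the GPU in lowvram
# mode: roughly half of (VRAM - 1GB), but at least 256MB.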

if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
    try:
        import accelerate
        accelerate_enabled = True
        vram_state = set_vram_to
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")

    total_vram_available_mb = (total_vram - 1024) // 2
    total_vram_available_mb = int(max(256, total_vram_available_mb))

try:
    if torch.backends.mps.is_available():
        vram_state = MPS
except:
    pass

if forced_cpu:
    vram_state = CPU

print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
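
# Module state: the single model currently loaded on the compute device, any
# controlnets resident on the GPU, and whether the model was dispatched through
# accelerate hooks.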

current_loaded_model = None
current_gpu_controlnets = []
model_accelerated = False


def unload_model():
    global current_loaded_model
    global model_accelerated
    global current_gpu_controlnets
    global vram_state

    if current_loaded_model is not None:
        if model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
            model_accelerated = False

        #never unload models from GPU on high vram
        if vram_state != HIGH_VRAM:
            current_loaded_model.model.cpu()
        current_loaded_model.unpatch_model()
        current_loaded_model = None

    if vram_state != HIGH_VRAM:
        if len(current_gpu_controlnets) > 0:
            for n in current_gpu_controlnets:
                n.cpu()
            current_gpu_controlnets = []
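
# load_model_gpu() moves a freshly patched model onto the device selected by
# vram_state; in LOW/NO VRAM modes it is dispatched through accelerate so part
# of the model can stay in CPU memory.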


def load_model_gpu(model):
    global current_loaded_model
    global vram_state
    global model_accelerated
    global xpu_available

    if model is current_loaded_model:
        return
    unload_model()
    try:
        real_model = model.patch_model()
    except Exception as e:
        model.unpatch_model()
        raise e
    current_loaded_model = model
    if vram_state == CPU:
        pass
    elif vram_state == MPS:
        mps_device = torch.device("mps")
        real_model.to(mps_device)
        pass
    elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
        model_accelerated = False
        if xpu_available:
            real_model.to("xpu")
        else:
            real_model.cuda()
    else:
        if vram_state == NO_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
        elif vram_state == LOW_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})

        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
        model_accelerated = True
    return current_loaded_model
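
# Controlnets are only kept resident on the GPU in NORMAL/HIGH VRAM modes; in
# low vram modes they are loaded right before running and unloaded right after.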

def load_controlnet_gpu(models):
    global current_gpu_controlnets
    global vram_state
    if vram_state == CPU:
        return

    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
        #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
        return

    for m in current_gpu_controlnets:
        if m not in models:
            m.cpu()

    device = get_torch_device()
    current_gpu_controlnets = []
    for m in models:
        current_gpu_controlnets.append(m.to(device))


def load_if_low_vram(model):
    global vram_state
    global xpu_available
    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
        if xpu_available:
            return model.to("xpu")
        else:
            return model.cuda()
    return model

def unload_if_low_vram(model):
    global vram_state
    if vram_state == LOW_VRAM or vram_state == NO_VRAM:
        return model.cpu()
    return model
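
# Helpers for resolving the torch device and the autocast device type for the
# current vram_state.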

def get_torch_device():
    global xpu_available
    if vram_state == MPS:
        return torch.device("mps")
    if vram_state == CPU:
        return torch.device("cpu")
    else:
        if xpu_available:
            return torch.device("xpu")
        else:
            return torch.cuda.current_device()

def get_autocast_device(dev):
    if hasattr(dev, 'type'):
        return dev.type
    return "cuda"


def xformers_enabled():
    if vram_state == CPU:
        return False
    return XFORMERS_IS_AVAILBLE


def xformers_enabled_vae():
    enabled = xformers_enabled()
    if not enabled:
        return False
    try:
        #0.0.18 has a bug where Nan is returned when inputs are too big (1152x1920 res images and above)
        if xformers.version.__version__ == "0.0.18":
            return False
    except:
        pass
    return enabled

def pytorch_attention_enabled():
    return ENABLE_PYTORCH_ATTENTION
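
# Free memory on a device, in bytes. For CUDA this adds the memory reserved by
# the torch caching allocator but not currently active to the free device
# memory reported by the driver.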

def get_free_memory(dev=None, torch_free_too=False):
    global xpu_available
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    else:
        if xpu_available:
            mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
            mem_free_torch = mem_free_total
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_cuda + mem_free_torch

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    else:
        return mem_free_total
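
# Rough heuristic, derived from the free memory in MB, for how large an area
# can be processed in one batch.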

def maximum_batch_area():
    global vram_state
    if vram_state == NO_VRAM:
        return 0

    memory_free = get_free_memory() / (1024 * 1024)
    area = ((memory_free - 1024) * 0.9) / (0.6)
    return int(max(area, 0))

def cpu_mode():
    global vram_state
    return vram_state == CPU

def mps_mode():
    global vram_state
    return vram_state == MPS
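
# fp16 is skipped on CPU/MPS/XPU and on NVIDIA cards with compute capability
# below 7.0; the 16 series cards are excluded by name since fp32 appears to be
# faster on them.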

def should_use_fp16():
    global xpu_available
    if cpu_mode() or mps_mode() or xpu_available:
        return False #TODO ?

    if torch.cuda.is_bf16_supported():
        return True

    props = torch.cuda.get_device_properties("cuda")
    if props.major < 7:
        return False

    #FP32 is faster on those cards?
    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
    for x in nvidia_16_series:
        if x in props.name:
            return False

    return True

#TODO: might be cleaner to put this somewhere else
import threading
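
# Cooperative cancellation: interrupt_current_processing() sets a flag under a
# lock; throw_exception_if_processing_interrupted() raises
# InterruptProcessingException and clears the flag when it is set.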

class InterruptProcessingException(Exception):
    pass

interrupt_processing_mutex = threading.RLock()

interrupt_processing = False
def interrupt_current_processing(value=True):
    global interrupt_processing
    global interrupt_processing_mutex
    with interrupt_processing_mutex:
        interrupt_processing = value

def processing_interrupted():
    global interrupt_processing
    global interrupt_processing_mutex
    with interrupt_processing_mutex:
        return interrupt_processing

def throw_exception_if_processing_interrupted():
    global interrupt_processing
    global interrupt_processing_mutex
    with interrupt_processing_mutex:
        if interrupt_processing:
            interrupt_processing = False
            raise InterruptProcessingException()