"""Device selection, VRAM-mode detection and model loading/offloading helpers.

Most of the configuration in this module runs once at import time, based on
the command-line arguments in ``comfy.cli_args`` and the detected hardware.
"""
import psutil
from enum import Enum
from comfy.cli_args import args
import torch

class VRAMState(Enum):
    """How models are placed in device memory; the active mode is chosen at import time."""
    DISABLED = 0    #No vram present: no need to move models to vram
    NO_VRAM = 1     #Very low vram: enable all the options to save vram
    LOW_VRAM = 2    # model is split between the device and the CPU with accelerate
    NORMAL_VRAM = 3 # default: the active model is moved to the device when loaded
    HIGH_VRAM = 4   # offload targets (e.g. the UNet) stay on the device instead of the CPU
    SHARED = 5      #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.

class CPUState(Enum):
    """Which kind of compute device was detected or requested at import time."""
    GPU = 0  # default: a CUDA, XPU or DirectML device is used
    CPU = 1  # forced by args.cpu
    MPS = 2  # Apple MPS backend, used when torch reports it available

# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
# Requested low/no-VRAM mode; only applied once `accelerate` imports successfully.
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU

total_vram = 0

lowvram_available = True
xpu_available = False

directml_enabled = False
if args.directml is not None:
    import torch_directml
    directml_enabled = True
    device_index = args.directml
    if device_index < 0:
        directml_device = torch_directml.device()
    else:
        directml_device = torch_directml.device(device_index)
    print("Using directml with device:", torch_directml.device_name(device_index))
    # torch_directml.disable_tiled_resources(True)
    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

try:
    import intel_extension_for_pytorch as ipex
    if torch.xpu.is_available():
        xpu_available = True
except Exception:
    # Intel XPU support is optional; any failure simply leaves it disabled.
    pass

try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except Exception:
    # Builds without MPS support may lack these attributes; keep the default state.
    pass

if args.cpu:
    cpu_state = CPUState.CPU

def get_torch_device():
    """Return the torch device models should run on.

    Priority: the DirectML device when DirectML is enabled, then MPS, then the
    CPU (when cpu_state is CPU), then XPU when available, otherwise the current
    CUDA device.
    """
    if directml_enabled:
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    if xpu_available:
        return torch.device("xpu")
    return torch.device(torch.cuda.current_device())

def get_total_memory(dev=None, torch_total_too=False):
    """Return the total memory of ``dev`` in bytes.

    Args:
        dev: device to query; defaults to get_torch_device().
        torch_total_too: when True, return ``(total, torch_total)``. On CUDA,
            ``torch_total`` is the memory currently reserved by torch's
            allocator; elsewhere it equals ``total``.

    CPU and MPS devices report physical system RAM, DirectML a fixed 1 GiB
    placeholder, XPU the device's total memory and CUDA the total reported
    by ``torch.cuda.mem_get_info``.
    """
    if dev is None:
        dev = get_torch_device()

    if getattr(dev, 'type', None) in ('cpu', 'mps'):
        total = psutil.virtual_memory().total
        total_torch = total
    elif directml_enabled:
        total = 1024 * 1024 * 1024 #TODO
        total_torch = total
    elif xpu_available:
        total = torch.xpu.get_device_properties(dev).total_memory
        total_torch = total
    else:
        reserved = torch.cuda.memory_stats(dev)['reserved_bytes.all.current']
        total = torch.cuda.mem_get_info(dev)[1]
        total_torch = reserved

    return (total, total_torch) if torch_total_too else total

total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
# Pick a VRAM mode from the detected sizes (in MB) unless normal VRAM or the CPU was requested.
# A low-VRAM request only takes effect later if `accelerate` can be imported.
if not args.normalvram and not args.cpu:
    if lowvram_available and total_vram <= 4096:
        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
        set_vram_to = VRAMState.LOW_VRAM
    elif total_vram > total_ram * 1.1 and total_vram > 14336:
        print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
        vram_state = VRAMState.HIGH_VRAM

# Exception type to catch for out-of-memory errors. torch.cuda.OutOfMemoryError may be
# missing from some torch builds, in which case any Exception is treated as a candidate.
OOM_EXCEPTION = getattr(torch.cuda, "OutOfMemoryError", Exception)

XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
    XFORMERS_IS_AVAILABLE = False
else:
    try:
        import xformers
        import xformers.ops
        XFORMERS_IS_AVAILABLE = True
        try:
            XFORMERS_VERSION = xformers.version.__version__
            print("xformers version:", XFORMERS_VERSION)
            if XFORMERS_VERSION.startswith("0.0.18"):
                print()
                print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
                print("Please downgrade or upgrade xformers to a different version.")
                print()
                XFORMERS_ENABLED_VAE = False
        except Exception:
            # The version lookup is informational only; xformers stays enabled.
            pass
    except Exception:
        # xformers is optional: any failure while importing it just disables it.
        XFORMERS_IS_AVAILABLE = False

def is_nvidia():
    """Return True when cpu_state is GPU and torch was built with CUDA (torch.version.cuda is set).

    Always returns a bool (never None), so callers can rely on the result directly.
    """
    return cpu_state == CPUState.GPU and bool(torch.version.cuda)

ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention

# If no attention implementation was requested and xformers is unavailable, default to
# PyTorch's scaled-dot-product attention on Nvidia GPUs running torch 2 or newer.
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
    try:
        if is_nvidia():
            torch_version = torch.version.__version__
            # Parse the whole major component so that e.g. "10.0" is not read as 1.
            if int(torch_version.split(".")[0]) >= 2:
                ENABLE_PYTORCH_ATTENTION = True
    except Exception:
        # Best effort: leave PyTorch attention disabled if detection fails.
        pass

if ENABLE_PYTORCH_ATTENTION:
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    # PyTorch attention takes precedence over xformers.
    XFORMERS_IS_AVAILABLE = False

# Explicit command-line VRAM modes override the size-based choice made earlier in this module.
if args.lowvram:
    set_vram_to = VRAMState.LOW_VRAM
    lowvram_available = True
elif args.novram:
    set_vram_to = VRAMState.NO_VRAM
elif args.highvram or args.gpu_only:
    vram_state = VRAMState.HIGH_VRAM

# Precision overrides consulted by should_use_fp16().
FORCE_FP32 = False
FORCE_FP16 = False
if args.force_fp32:
    print("Forcing FP32, if this improves things please report it.")
    FORCE_FP32 = True

if args.force_fp16:
    print("Forcing FP16.")
    FORCE_FP16 = True

# Low/no-VRAM modes rely on `accelerate` to split models between the device and the CPU.
if lowvram_available:
    try:
        import accelerate
        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
            vram_state = set_vram_to
    except Exception:
        import traceback
        print(traceback.format_exc())
        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
        lowvram_available = False

# Without a dedicated GPU there is no VRAM to manage; MPS memory is shared with the CPU.
if cpu_state != CPUState.GPU:
    vram_state = VRAMState.DISABLED

if cpu_state == CPUState.MPS:
    vram_state = VRAMState.SHARED

print(f"Set vram state to: {vram_state.name}")

def get_torch_device_name(device):
    """Return a human-readable description of ``device`` for logging.

    CUDA devices include the GPU name and the allocator backend (empty if the
    backend cannot be queried). Other devices with a ``type`` attribute are
    described by that type alone. Objects without a ``type`` attribute (e.g. a
    bare CUDA index) are described as CUDA devices.
    """
    if hasattr(device, 'type'):
        if device.type == "cuda":
            try:
                allocator_backend = torch.cuda.get_allocator_backend()
            except Exception:
                allocator_backend = ""
            return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
        else:
            return "{}".format(device.type)
    else:
        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

try:
    print("Device:", get_torch_device_name(get_torch_device()))
except Exception:
    # Reporting the device is informational; do not fail the import over it.
    print("Could not pick default device.")

# Module-level loading state shared by unload_model(), load_model_gpu() and load_controlnet_gpu().
current_loaded_model = None   # model most recently loaded by load_model_gpu(), or None
current_gpu_controlnets = []  # controlnet modules currently moved to the torch device
model_accelerated = False     # True while current_loaded_model is dispatched with accelerate

def unload_model():
    """Unload the current model and, outside HIGH_VRAM mode, GPU-resident controlnets.

    The current model has its accelerate hooks removed (if it was dispatched
    with accelerate), is unpatched, and it and its patches are moved to its
    ``offload_device``; the device cache is then emptied unless in HIGH_VRAM
    mode. Outside HIGH_VRAM mode, controlnets in ``current_gpu_controlnets``
    are moved to the CPU and the list is cleared.
    """
    global current_loaded_model
    global model_accelerated
    global current_gpu_controlnets

    if current_loaded_model is not None:
        if model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
            model_accelerated = False

        current_loaded_model.unpatch_model()
        current_loaded_model.model.to(current_loaded_model.offload_device)
        current_loaded_model.model_patches_to(current_loaded_model.offload_device)
        current_loaded_model = None
        if vram_state != VRAMState.HIGH_VRAM:
            soft_empty_cache()

    if vram_state != VRAMState.HIGH_VRAM and current_gpu_controlnets:
        for n in current_gpu_controlnets:
            n.cpu()
        current_gpu_controlnets = []

def minimum_inference_memory():
    """Return the memory headroom, in bytes, reserved for inference (768 MiB).

    Subtracted from free memory when deciding whether a model fits on the device.
    """
    return 768 * 1024 * 1024

def _accelerate_dispatch(real_model, torch_dev, device_max_memory):
    """Split real_model between device 0 (capped at device_max_memory) and the CPU with accelerate."""
    device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: device_max_memory, "cpu": "16GiB"})
    accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)


def load_model_gpu(model):
    """Make ``model`` the currently loaded model on ``model.load_device``.

    Any previously loaded model is unloaded first. Depending on the VRAM state
    (and, in LOW/NORMAL mode, on whether the model fits in free memory), the
    model is either patched straight onto its load device or split between
    the device and the CPU with accelerate.

    Returns:
        The loaded model, also when ``model`` was already the loaded one.

    Raises:
        Any exception from ``model.patch_model``, re-raised after the model
        has been unpatched and unloaded.
    """
    global current_loaded_model
    global model_accelerated

    if model is current_loaded_model:
        # Already loaded: return it, consistent with the normal path below.
        return current_loaded_model
    unload_model()

    torch_dev = model.load_device
    model.model_patches_to(torch_dev)
    model.model_patches_to(model.model_dtype())
    current_loaded_model = model

    if is_device_cpu(torch_dev):
        vram_set_state = VRAMState.DISABLED
    else:
        vram_set_state = vram_state

    if lowvram_available and vram_set_state in (VRAMState.LOW_VRAM, VRAMState.NORMAL_VRAM):
        model_size = model.model_size()
        current_free_mem = get_free_memory(torch_dev)
        # Device budget if the model must be split: (free - 1 GiB) / 1.3, at least 256 MiB.
        lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
        if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
            vram_set_state = VRAMState.LOW_VRAM

    real_model = model.model
    patch_model_to = None
    # DISABLED, LOW_VRAM and NO_VRAM patch with device_to=None; LOW/NO VRAM placement is done by accelerate below.
    if vram_set_state in (VRAMState.NORMAL_VRAM, VRAMState.HIGH_VRAM, VRAMState.SHARED):
        model_accelerated = False
        patch_model_to = torch_dev

    try:
        real_model = model.patch_model(device_to=patch_model_to)
    except Exception:
        model.unpatch_model()
        unload_model()
        raise

    if patch_model_to is not None:
        real_model.to(torch_dev)

    if vram_set_state == VRAMState.NO_VRAM:
        _accelerate_dispatch(real_model, torch_dev, "256MiB")
        model_accelerated = True
    elif vram_set_state == VRAMState.LOW_VRAM:
        _accelerate_dispatch(real_model, torch_dev, "{}MiB".format(lowvram_model_memory // (1024 * 1024)))
        model_accelerated = True

    return current_loaded_model

def load_controlnet_gpu(control_models):
    """Move the modules of ``control_models`` to the torch device.

    Does nothing when VRAM is DISABLED. In LOW_VRAM / NO_VRAM mode the
    controlnets are only switched to low-VRAM behaviour (through
    ``set_lowvram`` where supported) and are not moved here. Otherwise,
    previously loaded controlnet modules that are not in the new set are moved
    to the CPU, and the new set is moved to the device and recorded in
    ``current_gpu_controlnets``.
    """
    global current_gpu_controlnets
    if vram_state == VRAMState.DISABLED:
        return

    if vram_state in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        for m in control_models:
            if hasattr(m, 'set_lowvram'):
                m.set_lowvram(True)
        #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
        return

    models = []
    for m in control_models:
        models += m.get_models()

    for m in current_gpu_controlnets:
        if m not in models:
            m.cpu()

    device = get_torch_device()
    current_gpu_controlnets = [m.to(device) for m in models]

def load_if_low_vram(model):
    """Move ``model`` to the torch device in LOW_VRAM / NO_VRAM mode; otherwise return it unchanged."""
    if vram_state in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        return model.to(get_torch_device())
    return model

def unload_if_low_vram(model):
    """Move ``model`` to the CPU in LOW_VRAM / NO_VRAM mode; otherwise return it unchanged."""
    if vram_state in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        return model.cpu()
    return model

def unet_offload_device():
    """Return where the UNet is kept when idle: the torch device in HIGH_VRAM mode, else the CPU."""
    if vram_state == VRAMState.HIGH_VRAM:
        return get_torch_device()
    return torch.device("cpu")

def text_encoder_offload_device():
    """Return where the text encoder is kept when idle: the torch device with args.gpu_only, else the CPU."""
    if args.gpu_only:
        return get_torch_device()
    return torch.device("cpu")

def text_encoder_device():
    """Return the device the text encoder should run on.

    With ``args.gpu_only`` this is always the torch device. In HIGH/NORMAL
    VRAM mode the torch device is used only when torch reports fewer than 8
    threads; in every other case the CPU is used.
    """
    if args.gpu_only:
        return get_torch_device()
    if vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM):
        #NOTE: on a Ryzen 5 7600X with 4080 it's faster to shift to GPU
        if torch.get_num_threads() < 8: #leaving the text encoder on the CPU is faster than shifting it if the CPU is fast enough.
            return get_torch_device()
    return torch.device("cpu")

def vae_device():
    """Return the device the VAE runs on, which is always the main torch device."""
    device = get_torch_device()
    return device

def vae_offload_device():
    """Return where the VAE is kept when idle: the torch device with args.gpu_only, else the CPU."""
    if args.gpu_only:
        return get_torch_device()
    return torch.device("cpu")

def vae_dtype():
    """Return the VAE dtype requested on the command line.

    ``args.fp16_vae`` takes precedence over ``args.bf16_vae``; float32 is the default.
    """
    if args.fp16_vae:
        return torch.float16
    if args.bf16_vae:
        return torch.bfloat16
    return torch.float32

def get_autocast_device(dev):
    """Return ``dev.type``, or ``"cuda"`` when ``dev`` has no ``type`` attribute."""
    return getattr(dev, 'type', "cuda")

def xformers_enabled():
    """Return True if xformers attention should be used.

    Requires cpu_state GPU, no XPU and no DirectML backend, and
    XFORMERS_IS_AVAILABLE (decided at import time).
    """
    if cpu_state != CPUState.GPU:
        return False
    if xpu_available or directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE

def xformers_enabled_vae():
    """Return True if xformers may also be used in the VAE.

    False when xformers is disabled overall or when XFORMERS_ENABLED_VAE was
    cleared at import time because of a known-bad xformers version.
    """
    if not xformers_enabled():
        return False
    return XFORMERS_ENABLED_VAE

def pytorch_attention_enabled():
    """Return whether PyTorch attention was enabled at import time."""
    return ENABLE_PYTORCH_ATTENTION

def pytorch_attention_flash_attention():
    """Return True if PyTorch attention is enabled and running on an Nvidia (CUDA) build."""
    #TODO: more reliable way of checking for flash attention?
    #pytorch flash attention only works on Nvidia
    return bool(ENABLE_PYTORCH_ATTENTION and is_nvidia())

def get_free_memory(dev=None, torch_free_too=False):
    """Return the free memory on ``dev`` in bytes.

    Args:
        dev: device to query; defaults to get_torch_device().
        torch_free_too: when True, return ``(free_total, free_torch)``. On
            CUDA, ``free_torch`` is memory reserved by torch's allocator but
            not currently active; elsewhere it equals ``free_total``.

    CPU and MPS devices report available system RAM, DirectML a fixed 1 GiB
    placeholder, XPU the total device memory minus memory allocated by torch,
    and CUDA the free memory from ``torch.cuda.mem_get_info`` plus torch's
    reserved-but-inactive memory.
    """
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    elif directml_enabled:
        mem_free_total = 1024 * 1024 * 1024 #TODO
        mem_free_torch = mem_free_total
    elif xpu_available:
        mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
        mem_free_torch = mem_free_total
    else:
        stats = torch.cuda.memory_stats(dev)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    return mem_free_total

def maximum_batch_area():
    """Return a heuristic upper bound on batch area derived from free memory (in MiB).

    Returns 0 in NO_VRAM mode. A more generous formula is used when xformers
    or PyTorch flash attention is available. The result is never negative.
    """
    if vram_state == VRAMState.NO_VRAM:
        return 0

    memory_free = get_free_memory() / (1024 * 1024)
    if xformers_enabled() or pytorch_attention_flash_attention():
        #TODO: this needs to be tweaked
        area = 20 * memory_free
    else:
        #TODO: this conservative formula works around AMD memory management issues which might be fixed in the future
        area = ((memory_free - 1024) * 0.9) / (0.6)
    return int(max(area, 0))

def cpu_mode():
    """Return True when execution was forced onto the CPU."""
    return cpu_state == CPUState.CPU

def mps_mode():
    """Return True when the Apple MPS backend was selected."""
    return cpu_state == CPUState.MPS

def is_device_cpu(device):
    """Return True if ``device`` has a ``type`` attribute equal to ``'cpu'``."""
    return hasattr(device, 'type') and device.type == 'cpu'

def is_device_mps(device):
    """Return True if ``device`` has a ``type`` attribute equal to ``'mps'``."""
    return hasattr(device, 'type') and device.type == 'mps'

def should_use_fp16(device=None, model_params=0):
    """Decide whether a model should run in fp16.

    Args:
        device: target device. CPU and MPS devices never use fp16. When given,
            its CUDA properties and free memory are checked instead of the
            current/default device's.
        model_params: parameter count; on cards where fp16 works but is slower
            than fp32, fp16 is enabled only when an fp32 copy (4 bytes per
            parameter) would not fit in free memory.

    FORCE_FP16 wins over everything. FORCE_FP32, DirectML, CPU mode, MPS mode
    and XPU disable fp16. Otherwise, bf16-capable CUDA devices use fp16, cards
    below compute capability 6 never do, and a list of known-broken 16-series
    cards is excluded.
    """
    if FORCE_FP16:
        return True

    if device is not None: #TODO
        if is_device_cpu(device) or is_device_mps(device):
            return False

    if FORCE_FP32:
        return False

    if directml_enabled:
        return False

    if cpu_mode() or mps_mode() or xpu_available:
        return False #TODO ?

    if torch.cuda.is_bf16_supported():
        return True

    # Query the requested device rather than always the current CUDA device.
    props = torch.cuda.get_device_properties(device if device is not None else "cuda")
    if props.major < 6:
        return False

    #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
    #when the model doesn't actually fit on the card
    #TODO: actually test if GP106 and others have the same type of behavior
    nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050"]
    device_name_lower = props.name.lower()
    if any(x in device_name_lower for x in nvidia_10_series):
        free_model_memory = (get_free_memory(device) * 0.9 - minimum_inference_memory())
        if model_params * 4 > free_model_memory:
            return True

    if props.major < 7:
        return False

    #FP16 is just broken on these cards
    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450"]
    if any(x in props.name for x in nvidia_16_series):
        return False

    return True

def soft_empty_cache():
    """Ask torch to release cached device memory on the active backend.

    MPS and XPU caches are always emptied. On CUDA this is only done when
    is_nvidia() is true, and IPC memory is also collected there.
    """
    if cpu_state == CPUState.MPS:
        torch.mps.empty_cache()
    elif xpu_available:
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

#TODO: might be cleaner to put this somewhere else
import threading


class InterruptProcessingException(Exception):
    """Raised by throw_exception_if_processing_interrupted() when an interrupt is pending."""


# Re-entrant lock guarding the pending-interrupt flag below.
interrupt_processing_mutex = threading.RLock()

# Pending-interrupt flag; the helpers below only touch it while holding the mutex.
interrupt_processing = False


def interrupt_current_processing(value=True):
    """Set (by default) or clear the pending-interrupt flag."""
    global interrupt_processing
    with interrupt_processing_mutex:
        interrupt_processing = value


def processing_interrupted():
    """Return the current value of the pending-interrupt flag."""
    with interrupt_processing_mutex:
        return interrupt_processing


def throw_exception_if_processing_interrupted():
    """If an interrupt is pending, clear it and raise InterruptProcessingException."""
    global interrupt_processing
    with interrupt_processing_mutex:
        pending = interrupt_processing
        if pending:
            interrupt_processing = False
    if pending:
        raise InterruptProcessingException()