model_management.py 11.2 KB
Newer Older
1
2
import psutil
from enum import Enum
comfyanonymous's avatar
comfyanonymous committed
3
from .cli_args import args
4

5
6
7
8
9
10
11
class VRAMState(Enum):
    """Memory-management mode that decides where models are placed."""
    # Everything stays on the CPU; get_torch_device() returns the "cpu" device.
    CPU = 0
    # Model is dispatched through accelerate with a 256MiB cap on device 0.
    NO_VRAM = 1
    # Model is dispatched through accelerate, capped at total_vram_available_mb.
    LOW_VRAM = 2
    # Model is moved to the device as a whole and back to the CPU on unload.
    NORMAL_VRAM = 3
    # Like NORMAL_VRAM, but unload_model() never moves the model off the device.
    HIGH_VRAM = 4
    # Model is moved to torch's "mps" device.
    MPS = 5
12

13
14
15
# Determine VRAM State
# vram_state is the mode in effect; set_vram_to holds a requested low/no-VRAM
# mode that is only promoted to vram_state once accelerate has been imported.
vram_state = set_vram_to = VRAMState.NORMAL_VRAM

# Total device memory in MiB; stays 0 when detection does not run or fails.
total_vram = 0
# Device budget in MiB passed to accelerate in LOW_VRAM mode (-1 = not computed).
total_vram_available_mb = -1

# Flipped to True further down when the corresponding backend is usable.
accelerate_enabled = False
xpu_available = False
22

23
# DirectML is opt-in: it is only imported and initialised when args.directml
# carries a value.
directml_enabled = False
if args.directml is not None:
    import torch_directml
    directml_enabled = True
    device_index = args.directml
    # A negative index selects whatever torch_directml.device() returns when
    # called without an argument.
    directml_device = (
        torch_directml.device()
        if device_index < 0
        else torch_directml.device(device_index)
    )
    # NOTE(review): device_name() is still given the raw (possibly negative)
    # index here - confirm torch_directml accepts that.
    print("Using directml with device:", torch_directml.device_name(device_index))
    # torch_directml.disable_tiled_resources(True)

35
36
# Probe the accelerator for its total memory (in MiB) and derive a default
# VRAM mode from it. Detection is best-effort: if torch or the device query
# fails, the defaults assigned earlier in the module stay in place.
try:
    import torch
    if directml_enabled:
        total_vram = 4097 #TODO
    else:
        # True once an XPU has reported its memory size. Without this flag a
        # machine with intel_extension_for_pytorch installed but no XPU never
        # queried CUDA, leaving total_vram at 0 and forcing low-VRAM mode.
        _xpu_vram_detected = False
        try:
            import intel_extension_for_pytorch as ipex
            if torch.xpu.is_available():
                xpu_available = True
                total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
                _xpu_vram_detected = True
        except Exception:
            # ipex is missing or the XPU query failed; fall back to CUDA below.
            pass
        if not _xpu_vram_detected:
            total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
    total_ram = psutil.virtual_memory().total / (1024 * 1024)
    # Explicit --normalvram / --cpu switches disable the automatic choice.
    if not args.normalvram and not args.cpu:
        if total_vram <= 4096:
            print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
            set_vram_to = VRAMState.LOW_VRAM
        elif total_vram > total_ram * 1.1 and total_vram > 14336:
            print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
            vram_state = VRAMState.HIGH_VRAM
except Exception:
    # Deliberately ignored: no usable accelerator simply means no auto-tuning.
    pass

58
59
60
61
62
# Exception type used for out-of-memory errors. Falls back to the generic
# Exception when torch could not be imported (NameError) or does not provide
# torch.cuda.OutOfMemoryError (AttributeError).
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except (NameError, AttributeError):
    OOM_EXCEPTION = Exception

63
64
# xformers probing. XFORMERS_IS_AVAILABLE records whether the package could be
# imported (and was not disabled on the command line); XFORMERS_ENABLED_VAE is
# switched off for the 0.0.18 releases that the warning below describes.
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
    XFORMERS_IS_AVAILABLE = False
else:
    try:
        import xformers
        import xformers.ops
    except Exception:
        # Any import-time failure is treated as "xformers not installed".
        XFORMERS_IS_AVAILABLE = False
    else:
        XFORMERS_IS_AVAILABLE = True
        try:
            XFORMERS_VERSION = xformers.version.__version__
            print("xformers version:", XFORMERS_VERSION)
            if XFORMERS_VERSION.startswith("0.0.18"):
                print()
                print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
                print("Please downgrade or upgrade xformers to a different version.")
                print()
                XFORMERS_ENABLED_VAE = False
        except Exception:
            # The version lookup is informational only; keep xformers enabled.
            pass
85

86
87
# Opt-in switch for PyTorch's own scaled-dot-product attention kernels.
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
if ENABLE_PYTORCH_ATTENTION:
    # Turn on all three SDP backends.
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    # Requesting PyTorch attention turns xformers off.
    XFORMERS_IS_AVAILABLE = False
92

93
94
95
96
97
98
# Explicit command-line overrides; the first matching flag wins.
# --lowvram / --novram only record a request in set_vram_to (it is applied
# later, once accelerate has been imported), while --highvram takes effect
# on vram_state immediately.
if args.lowvram:
    set_vram_to = VRAMState.LOW_VRAM
elif args.novram:
    set_vram_to = VRAMState.NO_VRAM
elif args.highvram:
    vram_state = VRAMState.HIGH_VRAM
99

100
101
102
103
104
# When set, should_use_fp16() always answers False.
FORCE_FP32 = bool(args.force_fp32)
if FORCE_FP32:
    print("Forcing FP32, if this improves things please report it.")

105

106
# Low / no VRAM modes depend on accelerate. The requested mode only becomes
# the active vram_state when the import succeeds.
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
    try:
        import accelerate
    except Exception:
        import traceback
        print(traceback.format_exc())
        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
    else:
        accelerate_enabled = True
        vram_state = set_vram_to

    # Budget in MiB: half of what remains after setting 1024 MiB aside, with a
    # floor of 256.
    total_vram_available_mb = int(max(256, (total_vram - 1024) // 2))
118

119
120
# Final adjustments to vram_state. An available MPS backend replaces whatever
# mode was chosen above, and --cpu overrides everything.
try:
    if torch.backends.mps.is_available():
        vram_state = VRAMState.MPS
except Exception:
    # torch is missing or has no MPS backend: keep the current mode.
    pass

if args.cpu:
    vram_state = VRAMState.CPU

print(f"Set vram state to: {vram_state.name}")
129

130
131

# Model most recently loaded by load_model_gpu(), or None when nothing is loaded.
current_loaded_model = None
# Control-net models currently placed on the device by load_controlnet_gpu().
current_gpu_controlnets = []

# True while the loaded model is dispatched through accelerate.
model_accelerated = False


137
138
def unload_model():
    """Unpatch and release the currently loaded model and loaded control nets.

    Accelerate hooks are removed if the model was dispatched through
    accelerate. Unless the VRAM state is HIGH_VRAM, the model, its patches and
    the tracked control nets are moved back to the CPU.
    """
    global current_loaded_model
    global model_accelerated
    global current_gpu_controlnets

    # never unload models from GPU on high vram
    keep_on_device = vram_state == VRAMState.HIGH_VRAM

    loaded = current_loaded_model
    if loaded is not None:
        if model_accelerated:
            accelerate.hooks.remove_hook_from_submodules(loaded.model)
            model_accelerated = False

        if not keep_on_device:
            loaded.model.cpu()
            loaded.model_patches_to("cpu")

        loaded.unpatch_model()
        current_loaded_model = None

    if not keep_on_device and len(current_gpu_controlnets) > 0:
        for controlnet in current_gpu_controlnets:
            controlnet.cpu()
        current_gpu_controlnets = []
160
161
162
163


def load_model_gpu(model):
    """Make `model` the currently loaded model and place it per the VRAM state.

    Any previously loaded model is unloaded first. The model is patched, its
    patches are moved to the torch device, and the patched model is then:
    left where it is (CPU state), moved to the "mps" device (MPS state), moved
    to the torch device (NORMAL_VRAM / HIGH_VRAM), or dispatched through
    accelerate with a per-device memory cap (NO_VRAM / LOW_VRAM).

    Returns the loaded model wrapper (i.e. `model`), including when it was
    already the current model and nothing had to be done.

    Any exception raised by `model.patch_model()` is re-raised after the model
    has been unpatched again.
    """
    global current_loaded_model
    global model_accelerated

    if model is current_loaded_model:
        # Already loaded: return the same object the normal path returns
        # instead of an implicit None.
        return current_loaded_model
    unload_model()
    try:
        real_model = model.patch_model()
    except Exception:
        # Undo partial patching, then propagate with the original traceback.
        model.unpatch_model()
        raise

    model.model_patches_to(get_torch_device())
    current_loaded_model = model
    if vram_state == VRAMState.CPU:
        pass
    elif vram_state == VRAMState.MPS:
        mps_device = torch.device("mps")
        real_model.to(mps_device)
    elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
        model_accelerated = False
        real_model.to(get_torch_device())
    else:
        # NO_VRAM / LOW_VRAM: let accelerate split the model between device 0
        # (capped) and the CPU.
        if vram_state == VRAMState.NO_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
        elif vram_state == VRAMState.LOW_VRAM:
            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})

        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
        model_accelerated = True
    return current_loaded_model
196

197
def load_controlnet_gpu(control_models):
    """Move the models behind `control_models` onto the torch device.

    Does nothing in the CPU, LOW_VRAM and NO_VRAM states. Otherwise, tracked
    control nets that are not part of the new set are sent back to the CPU
    and the tracking list is rebuilt from the newly moved models.
    """
    global current_gpu_controlnets

    if vram_state == VRAMState.CPU:
        return

    # don't load controlnets like this if low vram because they will be
    # loaded right before running and unloaded right after
    if vram_state in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        return

    wanted = []
    for control in control_models:
        wanted.extend(control.get_models())

    for loaded in current_gpu_controlnets:
        if loaded not in wanted:
            loaded.cpu()

    target = get_torch_device()
    current_gpu_controlnets = []
    for net in wanted:
        current_gpu_controlnets.append(net.to(target))
comfyanonymous's avatar
comfyanonymous committed
219

220

221
222
def load_if_low_vram(model):
    """Return `model` moved to the torch device in LOW_VRAM / NO_VRAM states.

    In every other state `model` is returned unchanged.
    """
    if vram_state in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        return model.to(get_torch_device())
    return model

def unload_if_low_vram(model):
    """Return `model` moved to the CPU in LOW_VRAM / NO_VRAM states.

    In every other state `model` is returned unchanged.
    """
    if vram_state in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        return model.cpu()
    return model

233
def get_torch_device():
    """Return the device that models should be placed on.

    Checked in order: the DirectML device when DirectML is enabled, "mps" in
    the MPS state, "cpu" in the CPU state, "xpu" when an XPU was detected,
    and otherwise the value of torch.cuda.current_device().
    """
    if directml_enabled:
        return directml_device
    if vram_state == VRAMState.MPS:
        return torch.device("mps")
    if vram_state == VRAMState.CPU:
        return torch.device("cpu")
    if xpu_available:
        return torch.device("xpu")
    # Note: this is whatever torch.cuda.current_device() returns, not a
    # torch.device built by this function.
    return torch.cuda.current_device()
248
249
250
251
252

def get_autocast_device(dev):
    """Return `dev.type` when `dev` has a `type` attribute, else "cuda"."""
    return getattr(dev, 'type', "cuda")
253

254

255
def xformers_enabled():
    """Return whether xformers should be used.

    Always False in the CPU state, on XPU and with DirectML; otherwise the
    value of XFORMERS_IS_AVAILABLE.
    """
    if vram_state == VRAMState.CPU or xpu_available or directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE
265

266
267
268
269
270

def xformers_enabled_vae():
    """Return whether xformers should be used for the VAE.

    False whenever xformers_enabled() is falsy; otherwise the value of
    XFORMERS_ENABLED_VAE.
    """
    return XFORMERS_ENABLED_VAE if xformers_enabled() else False
273

274
275
276
def pytorch_attention_enabled():
    """Return the module-level ENABLE_PYTORCH_ATTENTION setting."""
    return ENABLE_PYTORCH_ATTENTION

277
def get_free_memory(dev=None, torch_free_too=False):
    """Estimate the free memory on `dev`.

    dev: device to inspect; defaults to get_torch_device().
    torch_free_too: when true, return a `(total_free, torch_free)` tuple
        instead of just `total_free`.

    For "cpu" and "mps" devices the available system memory reported by
    psutil is used for both values. DirectML reports a fixed placeholder. On
    XPU the figure is the device's total memory minus what torch has
    allocated. On CUDA, `torch_free` is memory reserved by torch but not
    active, and `total_free` adds the free memory reported by the driver.
    """
    if dev is None:
        dev = get_torch_device()

    uses_host_memory = hasattr(dev, 'type') and dev.type in ('cpu', 'mps')
    if uses_host_memory:
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    elif directml_enabled:
        mem_free_total = 1024 * 1024 * 1024 #TODO
        mem_free_torch = mem_free_total
    elif xpu_available:
        mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
        mem_free_torch = mem_free_total
    else:
        stats = torch.cuda.memory_stats(dev)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    return mem_free_total
305
306
307

def maximum_batch_area():
    """Return a non-negative integer batch-area limit derived from free memory.

    Always 0 in the NO_VRAM state. Otherwise 1024 is subtracted from the free
    memory (converted by dividing by 1024 * 1024), and the remainder is scaled
    by 0.9 and divided by 0.6.
    """
    if vram_state == VRAMState.NO_VRAM:
        return 0

    free_mb = get_free_memory() / (1024 * 1024)
    usable = (free_mb - 1024) * 0.9
    return int(max(usable / (0.6), 0))
314
315
316

def cpu_mode():
    """Return True when the active VRAM state is CPU."""
    return vram_state == VRAMState.CPU
318

Yurii Mazurevich's avatar
Yurii Mazurevich committed
319
320
def mps_mode():
    """Return True when the active VRAM state is MPS."""
    return vram_state == VRAMState.MPS
Yurii Mazurevich's avatar
Yurii Mazurevich committed
322

323
def should_use_fp16():
    """Decide whether models should run in half precision.

    Returns False when FP32 is forced, with DirectML, and in CPU / MPS / XPU
    setups. On CUDA it returns True if bf16 is supported; otherwise False for
    devices with compute capability major < 7 or whose name matches one of
    the listed 16-series / T-series cards, and True for everything else.
    """
    if FORCE_FP32 or directml_enabled:
        return False

    if cpu_mode() or mps_mode() or xpu_available:
        return False #TODO ?

    if torch.cuda.is_bf16_supported():
        return True

    props = torch.cuda.get_device_properties("cuda")
    if props.major < 7:
        return False

    #FP32 is faster on those cards?
    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600"]
    if any(tag in props.name for tag in nvidia_16_series):
        return False

    return True

351
352
353
354
355
356
357
358
359
def soft_empty_cache():
    """Release cached device memory held by torch, where that is supported.

    Empties the XPU cache when an XPU was detected. Otherwise, when CUDA is
    available and torch was built against CUDA, empties the CUDA cache and
    collects IPC memory.
    """
    if xpu_available:
        torch.xpu.empty_cache()
        return
    #This seems to make things worse on ROCm so I only do it for cuda
    if torch.cuda.is_available() and torch.version.cuda:
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
#TODO: might be cleaner to put this somewhere else
import threading

class InterruptProcessingException(Exception):
    """Raised by throw_exception_if_processing_interrupted() when an interrupt was requested."""

# Re-entrant lock held around every read and write of `interrupt_processing`.
interrupt_processing_mutex = threading.RLock()

# Interrupt request flag: set through interrupt_current_processing() and reset
# by throw_exception_if_processing_interrupted() just before it raises.
interrupt_processing = False
def interrupt_current_processing(value=True):
    """Set the interrupt flag to `value` (default True) under the mutex."""
    global interrupt_processing
    with interrupt_processing_mutex:
        interrupt_processing = value

def processing_interrupted():
    """Return the current interrupt flag, read under the mutex."""
    with interrupt_processing_mutex:
        flag = interrupt_processing
    return flag

def throw_exception_if_processing_interrupted():
    """Raise InterruptProcessingException if an interrupt is pending.

    The flag is cleared before raising, so each request raises once.
    Does nothing when no interrupt is pending.
    """
    global interrupt_processing
    with interrupt_processing_mutex:
        if not interrupt_processing:
            return
        interrupt_processing = False
        raise InterruptProcessingException()