chenpangpang / ComfyUI · Commits

Commit eb448dd8, authored May 30, 2023 by comfyanonymous
Parent: 560e9f7a

Auto load model in lowvram if not enough memory.
Showing 2 changed files with 45 additions and 19 deletions (+45 -19).
comfy/model_management.py   +29 -17
comfy/sd.py                 +16 -2
comfy/model_management.py
@@ -15,9 +15,8 @@ vram_state = VRAMState.NORMAL_VRAM
 set_vram_to = VRAMState.NORMAL_VRAM
 total_vram = 0
-total_vram_available_mb = -1
-accelerate_enabled = False
+lowvram_available = True
 xpu_available = False
 directml_enabled = False
@@ -31,11 +30,12 @@ if args.directml is not None:
         directml_device = torch_directml.device(device_index)
     print("Using directml with device:", torch_directml.device_name(device_index))
     # torch_directml.disable_tiled_resources(True)
+    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

 try:
     import torch
     if directml_enabled:
-        total_vram = 4097 #TODO
+        pass #TODO
     else:
         try:
             import intel_extension_for_pytorch as ipex
@@ -46,7 +46,7 @@ try:
             total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
     if not args.normalvram and not args.cpu:
-        if total_vram <= 4096:
+        if lowvram_available and total_vram <= 4096:
             print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
             set_vram_to = VRAMState.LOW_VRAM
         elif total_vram > total_ram * 1.1 and total_vram > 14336:
@@ -92,6 +92,7 @@ if ENABLE_PYTORCH_ATTENTION:
 if args.lowvram:
     set_vram_to = VRAMState.LOW_VRAM
+    lowvram_available = True
 elif args.novram:
     set_vram_to = VRAMState.NO_VRAM
 elif args.highvram:
@@ -103,18 +104,18 @@ if args.force_fp32:
     FORCE_FP32 = True

-if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+if lowvram_available:
     try:
         import accelerate
-        accelerate_enabled = True
-        vram_state = set_vram_to
+        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+            vram_state = set_vram_to
     except Exception as e:
         import traceback
         print(traceback.format_exc())
-        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
-
-    total_vram_available_mb = (total_vram - 1024) // 2
-    total_vram_available_mb = int(max(256, total_vram_available_mb))
+        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
+        lowvram_available = False

 try:
     if torch.backends.mps.is_available():
@@ -199,22 +200,33 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e

-    model.model_patches_to(get_torch_device())
+    torch_dev = get_torch_device()
+    model.model_patches_to(torch_dev)
+
+    vram_set_state = vram_state
+    if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+        model_size = model.model_size()
+        current_free_mem = get_free_memory(torch_dev)
+        lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2))
+        if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
+            vram_set_state = VRAMState.LOW_VRAM
+
     current_loaded_model = model
-    if vram_state == VRAMState.CPU:
+    if vram_set_state == VRAMState.CPU:
         pass
-    elif vram_state == VRAMState.MPS:
+    elif vram_set_state == VRAMState.MPS:
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
+    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
         real_model.to(get_torch_device())
     else:
-        if vram_state == VRAMState.NO_VRAM:
+        if vram_set_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
-        elif vram_state == VRAMState.LOW_VRAM:
-            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
+        elif vram_set_state == VRAMState.LOW_VRAM:
+            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})

         accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
         model_accelerated = True
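The substantive change above is in load_model_gpu: low-VRAM mode is no longer fixed at startup but chosen per model load, by comparing the model's size against the VRAM that is actually free and falling back to LOW_VRAM only when the model would not fit. Below is a minimal, self-contained sketch of that decision; the string states and plain byte-count arguments stand in for the real VRAMState enum and get_free_memory(), so it illustrates the heuristic rather than reproducing code from the repository.

MB = 1024 * 1024

def pick_vram_state(vram_state, model_size, free_mem, lowvram_available=True):
    # Mirrors the hunk at @@ -199,22 +200,33 @@: decide whether this load should
    # drop to LOW_VRAM and how much VRAM the accelerate device map may use.
    lowvram_model_memory = 0
    if lowvram_available and vram_state in ("LOW_VRAM", "NORMAL_VRAM"):
        # Budget: free VRAM minus ~1 GiB of headroom, divided by 1.2, floored at 256 MiB.
        lowvram_model_memory = int(max(256 * MB, (free_mem - 1024 * MB) / 1.2))
        # Only switch to lowvram if really necessary: the model must fit with ~512 MiB to spare.
        if model_size > (free_mem - 512 * MB):
            vram_state = "LOW_VRAM"
    return vram_state, lowvram_model_memory

# Example: a ~2.1 GiB checkpoint with 2.5 GiB of free VRAM gets auto-switched:
# pick_vram_state("NORMAL_VRAM", 2100 * MB, 2500 * MB) -> ("LOW_VRAM", ~1230 MiB budget)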
comfy/sd.py
@@ -286,15 +286,29 @@ def model_lora_keys(model, key_map={}):
     return key_map

 class ModelPatcher:
-    def __init__(self, model):
+    def __init__(self, model, size=0):
+        self.size = size
         self.model = model
         self.patches = []
         self.backup = {}
         self.model_options = {"transformer_options":{}}
+        self.model_size()
+
+    def model_size(self):
+        if self.size > 0:
+            return self.size
+        model_sd = self.model.state_dict()
+        size = 0
+        for k in model_sd:
+            t = model_sd[k]
+            size += t.nelement() * t.element_size()
+        self.size = size
+        return size

     def clone(self):
-        n = ModelPatcher(self.model)
+        n = ModelPatcher(self.model, self.size)
         n.patches = self.patches[:]
         n.model_options = copy.deepcopy(self.model_options)
         return n
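The comfy/sd.py side of the commit gives ModelPatcher a cached model_size(), which is what the load-time check above queries: the byte count is simply the sum of nelement() * element_size() over the model's state dict, computed once and reused via the size argument when cloning. A small stand-alone sketch of the same accounting follows; the nn.Linear example model is hypothetical and only there to show the arithmetic.

import torch.nn as nn

def state_dict_size_bytes(model: nn.Module) -> int:
    # Same accounting as ModelPatcher.model_size(): bytes occupied by every
    # tensor (parameters and buffers) in the model's state dict.
    size = 0
    for t in model.state_dict().values():
        size += t.nelement() * t.element_size()
    return size

# Example: nn.Linear(1024, 1024) in fp32 -> 1024*1024*4 + 1024*4 bytes, about 4.0 MiB.
print(state_dict_size_bytes(nn.Linear(1024, 1024)) / (1024 * 1024))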