chenpangpang / ComfyUI
"examples/run_bert_squad.py" did not exist on "01ff4f82ba810ea9032e81fdbbfa1a6ff28c3379"
Commit 51dde87e, authored Aug 24, 2023 by comfyanonymous

    Try to free enough vram for control lora inference.

parent e3d0a9a4
Showing 4 changed files with 30 additions and 18 deletions
comfy/model_management.py  +7 -5
comfy/sample.py            +6 -4
comfy/sd.py                +10 -9
comfy/utils.py             +7 -0
comfy/model_management.py

```diff
@@ -394,6 +394,12 @@ def cleanup_models():
         x.model_unload()
         del x
 
+def dtype_size(dtype):
+    dtype_size = 4
+    if dtype == torch.float16 or dtype == torch.bfloat16:
+        dtype_size = 2
+    return dtype_size
+
 def unet_offload_device():
     if vram_state == VRAMState.HIGH_VRAM:
         return get_torch_device()
@@ -409,11 +415,7 @@ def unet_inital_load_device(parameters, dtype):
     if DISABLE_SMART_MEMORY:
         return cpu_dev
 
-    dtype_size = 4
-    if dtype == torch.float16 or dtype == torch.bfloat16:
-        dtype_size = 2
-    model_size = dtype_size * parameters
+    model_size = dtype_size(dtype) * parameters
 
     mem_dev = get_free_memory(torch_dev)
     mem_cpu = get_free_memory(cpu_dev)
```
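The inline bytes-per-element calculation becomes the reusable dtype_size helper, so other call sites (like the new ControlLora estimate further down) can share it. A minimal sketch of the resulting weight-memory arithmetic; the 860M parameter count is an illustrative figure, not something from this commit:

```python
import torch

def dtype_size(dtype):
    # 4 bytes per element by default (fp32); half-precision dtypes take 2
    if dtype == torch.float16 or dtype == torch.bfloat16:
        return 2
    return 4

# Weight-memory estimate as in unet_inital_load_device, for a hypothetical
# 860M-parameter UNet:
parameters = 860_000_000
for dt in (torch.float32, torch.float16):
    model_size = dtype_size(dt) * parameters  # bytes
    print(dt, f"-> {model_size / 1024**3:.2f} GiB")
# torch.float32 -> 3.20 GiB, torch.float16 -> 1.60 GiB
```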
comfy/sample.py

```diff
@@ -51,18 +51,20 @@ def get_models_from_cond(cond, model_type):
             models += [c[1][model_type]]
     return models
 
-def get_additional_models(positive, negative):
+def get_additional_models(positive, negative, dtype):
     """loads additional models in positive and negative conditioning"""
     control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
 
+    inference_memory = 0
     control_models = []
     for m in control_nets:
         control_models += m.get_models()
+        inference_memory += m.inference_memory_requirements(dtype)
 
     gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
     gligen = [x[1] for x in gligen]
     models = control_models + gligen
-    return models
+    return models, inference_memory
 
 def cleanup_additional_models(models):
     """cleanup additional models that were loaded"""
@@ -77,8 +79,8 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     noise_mask = prepare_mask(noise_mask, noise.shape, device)
 
     real_model = None
-    models = get_additional_models(positive, negative)
-    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
+    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
+    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
     real_model = model.model
 
     noise = noise.to(device)
```
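The effect at the sample() call site: load_models_gpu is now asked for the batch-area estimate plus whatever the control models report they will need, so VRAM is freed up front rather than running out mid-step. A self-contained sketch of the budgeting, with a stub in place of ComfyUI's real batch_area_memory formula and an invented 300M-parameter control model:

```python
import torch

def batch_area_memory_stub(area):
    # placeholder scaling only; ComfyUI's real formula lives in
    # comfy/model_management.py and depends on the attention backend
    return area * 1024

def get_additional_models_stub(dtype):
    # pretend one 300M-parameter control model was found in the conditioning
    bytes_per_param = 2 if dtype in (torch.float16, torch.bfloat16) else 4
    inference_memory = 300_000_000 * bytes_per_param
    return ["control_model"], inference_memory

noise_shape = (1, 4, 64, 64)  # latent tensor for a single 512x512 image
area = noise_shape[0] * noise_shape[2] * noise_shape[3]

models, inference_memory = get_additional_models_stub(torch.float16)
required = batch_area_memory_stub(area) + inference_memory  # the new sum
print(f"reserve ~{required / 1024**2:.0f} MiB before sampling")  # ~576 MiB
```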
comfy/sd.py

```diff
@@ -779,6 +779,11 @@ class ControlBase:
         c.strength = self.strength
         c.timestep_percent_range = self.timestep_percent_range
 
+    def inference_memory_requirements(self, dtype):
+        if self.previous_controlnet is not None:
+            return self.previous_controlnet.inference_memory_requirements(dtype)
+        return 0
+
     def control_merge(self, control_input, control_output, control_prev, output_dtype):
         out = {'input':[], 'middle':[], 'output': []}
```
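ControlBase provides the recursive default: a controlnet reports only what the one chained before it needs, so a plain controlnet whose weights are already loaded contributes nothing extra. A toy sketch of the chaining, with made-up class and variable names:

```python
class ControlBaseSketch:
    def __init__(self, previous_controlnet=None):
        self.previous_controlnet = previous_controlnet

    def inference_memory_requirements(self, dtype):
        # walk the chain; the base class itself reserves nothing
        if self.previous_controlnet is not None:
            return self.previous_controlnet.inference_memory_requirements(dtype)
        return 0

first = ControlBaseSketch()
second = ControlBaseSketch(previous_controlnet=first)  # chained controlnets
print(second.inference_memory_requirements("float16"))  # 0
```

Subclasses that do materialize weights at sampling time override this and add their own term on top, as ControlLora does in the next hunk.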
```diff
@@ -985,6 +990,9 @@ class ControlLora(ControlNet):
         out = ControlBase.get_models(self)
         return out
 
+    def inference_memory_requirements(self, dtype):
+        return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
     if "lora_controlnet" in controlnet_data:
```
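ControlLora builds its controlnet weights on the fly during sampling, so unlike a regular controlnet it must reserve room for them up front: the element count of the stored control_weights times bytes per element, plus anything earlier in the chain. A back-of-envelope check with toy tensors (the shapes are invented):

```python
import torch

def dtype_size(dtype):  # mirrors the helper added in comfy/model_management.py
    return 2 if dtype in (torch.float16, torch.bfloat16) else 4

# Toy stand-in for self.control_weights: 750 elements in total.
control_weights = {"w1": torch.zeros(500), "w2": torch.zeros(250)}
n_params = sum(t.nelement() for t in control_weights.values())

reserve = n_params * dtype_size(torch.float16)  # + chained requirements, if any
print(reserve, "bytes")  # 1500
```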
```diff
@@ -1323,13 +1331,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
-def calculate_parameters(sd, prefix):
-    params = 0
-    for k in sd.keys():
-        if k.startswith(prefix):
-            params += sd[k].nelement()
-    return params
-
 def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
     sd = utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
@@ -1339,7 +1340,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model = None
     clip_target = None
 
-    parameters = calculate_parameters(sd, "model.diffusion_model.")
+    parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
     class WeightsLoader(torch.nn.Module):
@@ -1390,7 +1391,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 def load_unet(unet_path): #load unet in diffusers format
     sd = utils.load_torch_file(unet_path)
-    parameters = calculate_parameters(sd, "")
+    parameters = utils.calculate_parameters(sd)
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
     model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
```
comfy/utils.py

```diff
@@ -32,6 +32,13 @@ def save_torch_file(sd, ckpt, metadata=None):
     else:
         safetensors.torch.save_file(sd, ckpt)
 
+def calculate_parameters(sd, prefix=""):
+    params = 0
+    for k in sd.keys():
+        if k.startswith(prefix):
+            params += sd[k].nelement()
+    return params
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",
```
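calculate_parameters moves from comfy/sd.py into comfy/utils.py and gains a default prefix of "", which is why load_unet can now call utils.calculate_parameters(sd) with no second argument. A quick demonstration on a toy state dict:

```python
import torch

def calculate_parameters(sd, prefix=""):
    params = 0
    for k in sd.keys():
        if k.startswith(prefix):
            params += sd[k].nelement()
    return params

sd = {
    "model.diffusion_model.w": torch.zeros(10, 10),  # 100 elements
    "first_stage_model.w": torch.zeros(5),           # 5 elements
}
print(calculate_parameters(sd))                            # 105: "" matches every key
print(calculate_parameters(sd, "model.diffusion_model."))  # 100: UNet keys only
```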