Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
983ebc57
"tests/models/vscode:/vscode.git/clone" did not exist on "4974b84564d25bd4b5c594db4e04cb885cc0a9ed"
Commit
983ebc57
authored
Nov 28, 2023
by
comfyanonymous
Browse files
Use smart model management for VAE to decrease latency.
parent
798a34d0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
11 deletions
+7
-11
comfy/sd.py
comfy/sd.py
+7
-11
No files found.
comfy/sd.py
View file @
983ebc57
...
@@ -187,10 +187,12 @@ class VAE:
...
@@ -187,10 +187,12 @@ class VAE:
if
device
is
None
:
if
device
is
None
:
device
=
model_management
.
vae_device
()
device
=
model_management
.
vae_device
()
self
.
device
=
device
self
.
device
=
device
self
.
offload_device
=
model_management
.
vae_offload_device
()
offload_device
=
model_management
.
vae_offload_device
()
self
.
vae_dtype
=
model_management
.
vae_dtype
()
self
.
vae_dtype
=
model_management
.
vae_dtype
()
self
.
first_stage_model
.
to
(
self
.
vae_dtype
)
self
.
first_stage_model
.
to
(
self
.
vae_dtype
)
self
.
patcher
=
comfy
.
model_patcher
.
ModelPatcher
(
self
.
first_stage_model
,
load_device
=
self
.
device
,
offload_device
=
offload_device
)
def
decode_tiled_
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
def
decode_tiled_
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
steps
=
samples
.
shape
[
0
]
*
comfy
.
utils
.
get_tiled_scale_steps
(
samples
.
shape
[
3
],
samples
.
shape
[
2
],
tile_x
,
tile_y
,
overlap
)
steps
=
samples
.
shape
[
0
]
*
comfy
.
utils
.
get_tiled_scale_steps
(
samples
.
shape
[
3
],
samples
.
shape
[
2
],
tile_x
,
tile_y
,
overlap
)
steps
+=
samples
.
shape
[
0
]
*
comfy
.
utils
.
get_tiled_scale_steps
(
samples
.
shape
[
3
],
samples
.
shape
[
2
],
tile_x
//
2
,
tile_y
*
2
,
overlap
)
steps
+=
samples
.
shape
[
0
]
*
comfy
.
utils
.
get_tiled_scale_steps
(
samples
.
shape
[
3
],
samples
.
shape
[
2
],
tile_x
//
2
,
tile_y
*
2
,
overlap
)
...
@@ -219,10 +221,9 @@ class VAE:
...
@@ -219,10 +221,9 @@ class VAE:
return
samples
return
samples
def
decode
(
self
,
samples_in
):
def
decode
(
self
,
samples_in
):
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
device
)
try
:
try
:
memory_used
=
self
.
memory_used_decode
(
samples_in
.
shape
,
self
.
vae_dtype
)
memory_used
=
self
.
memory_used_decode
(
samples_in
.
shape
,
self
.
vae_dtype
)
model_management
.
free_memory
(
memory_used
,
self
.
device
)
model_management
.
load_models_gpu
([
self
.
patcher
],
memory_required
=
memory_used
)
free_memory
=
model_management
.
get_free_memory
(
self
.
device
)
free_memory
=
model_management
.
get_free_memory
(
self
.
device
)
batch_number
=
int
(
free_memory
/
memory_used
)
batch_number
=
int
(
free_memory
/
memory_used
)
batch_number
=
max
(
1
,
batch_number
)
batch_number
=
max
(
1
,
batch_number
)
...
@@ -235,22 +236,19 @@ class VAE:
...
@@ -235,22 +236,19 @@ class VAE:
print
(
"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding."
)
print
(
"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding."
)
pixel_samples
=
self
.
decode_tiled_
(
samples_in
)
pixel_samples
=
self
.
decode_tiled_
(
samples_in
)
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
offload_device
)
pixel_samples
=
pixel_samples
.
cpu
().
movedim
(
1
,
-
1
)
pixel_samples
=
pixel_samples
.
cpu
().
movedim
(
1
,
-
1
)
return
pixel_samples
return
pixel_samples
def
decode_tiled
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
def
decode_tiled
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
self
.
first_stage_model
=
self
.
first_stage
_model
.
to
(
self
.
device
)
model_management
.
load
_model
_gpu
(
self
.
patcher
)
output
=
self
.
decode_tiled_
(
samples
,
tile_x
,
tile_y
,
overlap
)
output
=
self
.
decode_tiled_
(
samples
,
tile_x
,
tile_y
,
overlap
)
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
offload_device
)
return
output
.
movedim
(
1
,
-
1
)
return
output
.
movedim
(
1
,
-
1
)
def
encode
(
self
,
pixel_samples
):
def
encode
(
self
,
pixel_samples
):
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
device
)
pixel_samples
=
pixel_samples
.
movedim
(
-
1
,
1
)
pixel_samples
=
pixel_samples
.
movedim
(
-
1
,
1
)
try
:
try
:
memory_used
=
self
.
memory_used_encode
(
pixel_samples
.
shape
,
self
.
vae_dtype
)
memory_used
=
self
.
memory_used_encode
(
pixel_samples
.
shape
,
self
.
vae_dtype
)
model_management
.
free_memory
(
memory_used
,
self
.
device
)
model_management
.
load_models_gpu
([
self
.
patcher
],
memory_required
=
memory_used
)
free_memory
=
model_management
.
get_free_memory
(
self
.
device
)
free_memory
=
model_management
.
get_free_memory
(
self
.
device
)
batch_number
=
int
(
free_memory
/
memory_used
)
batch_number
=
int
(
free_memory
/
memory_used
)
batch_number
=
max
(
1
,
batch_number
)
batch_number
=
max
(
1
,
batch_number
)
...
@@ -263,14 +261,12 @@ class VAE:
...
@@ -263,14 +261,12 @@ class VAE:
print
(
"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding."
)
print
(
"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding."
)
samples
=
self
.
encode_tiled_
(
pixel_samples
)
samples
=
self
.
encode_tiled_
(
pixel_samples
)
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
offload_device
)
return
samples
return
samples
def
encode_tiled
(
self
,
pixel_samples
,
tile_x
=
512
,
tile_y
=
512
,
overlap
=
64
):
def
encode_tiled
(
self
,
pixel_samples
,
tile_x
=
512
,
tile_y
=
512
,
overlap
=
64
):
self
.
first_stage_model
=
self
.
first_stage
_model
.
to
(
self
.
device
)
model_management
.
load
_model
_gpu
(
self
.
patcher
)
pixel_samples
=
pixel_samples
.
movedim
(
-
1
,
1
)
pixel_samples
=
pixel_samples
.
movedim
(
-
1
,
1
)
samples
=
self
.
encode_tiled_
(
pixel_samples
,
tile_x
=
tile_x
,
tile_y
=
tile_y
,
overlap
=
overlap
)
samples
=
self
.
encode_tiled_
(
pixel_samples
,
tile_x
=
tile_x
,
tile_y
=
tile_y
,
overlap
=
overlap
)
self
.
first_stage_model
=
self
.
first_stage_model
.
to
(
self
.
offload_device
)
return
samples
return
samples
def
get_sd
(
self
):
def
get_sd
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment