Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
6425252c
"comfy/ldm/vscode:/vscode.git/clone" did not exist on "61b3f15f8f2bc0822cb98eac48742fb32f6af396"
Commit
6425252c
authored
Jun 16, 2024
by
comfyanonymous
Browse files
Use fp16 as the default vae dtype for the audio VAE.
parent
8ddc151a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
16 deletions
+24
-16
comfy/model_management.py
comfy/model_management.py
+20
-15
comfy/sd.py
comfy/sd.py
+4
-1
No files found.
comfy/model_management.py
View file @
6425252c
...
...
@@ -167,7 +167,7 @@ if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION
=
True
XFORMERS_IS_AVAILABLE
=
False
VAE_DTYPE
=
torch
.
float32
VAE_DTYPE
S
=
[
torch
.
float32
]
try
:
if
is_nvidia
():
...
...
@@ -176,7 +176,7 @@ try:
if
ENABLE_PYTORCH_ATTENTION
==
False
and
args
.
use_split_cross_attention
==
False
and
args
.
use_quad_cross_attention
==
False
:
ENABLE_PYTORCH_ATTENTION
=
True
if
torch
.
cuda
.
is_bf16_supported
()
and
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
major
>=
8
:
VAE_DTYPE
=
torch
.
bfloat16
VAE_DTYPE
S
=
[
torch
.
bfloat16
]
+
VAE_DTYPES
if
is_intel_xpu
():
if
args
.
use_split_cross_attention
==
False
and
args
.
use_quad_cross_attention
==
False
:
ENABLE_PYTORCH_ATTENTION
=
True
...
...
@@ -184,17 +184,10 @@ except:
pass
if
is_intel_xpu
():
VAE_DTYPE
=
torch
.
bfloat16
VAE_DTYPE
S
=
[
torch
.
bfloat16
]
+
VAE_DTYPES
if
args
.
cpu_vae
:
VAE_DTYPE
=
torch
.
float32
if
args
.
fp16_vae
:
VAE_DTYPE
=
torch
.
float16
elif
args
.
bf16_vae
:
VAE_DTYPE
=
torch
.
bfloat16
elif
args
.
fp32_vae
:
VAE_DTYPE
=
torch
.
float32
VAE_DTYPES
=
[
torch
.
float32
]
if
ENABLE_PYTORCH_ATTENTION
:
...
...
@@ -258,7 +251,6 @@ try:
except
:
logging
.
warning
(
"Could not pick default device."
)
logging
.
info
(
"VAE dtype: {}"
.
format
(
VAE_DTYPE
))
current_loaded_models
=
[]
...
...
@@ -619,9 +611,22 @@ def vae_offload_device():
else
:
return
torch
.
device
(
"cpu"
)
def
vae_dtype
():
global
VAE_DTYPE
return
VAE_DTYPE
def
vae_dtype
(
device
=
None
,
allowed_dtypes
=
[]):
global
VAE_DTYPES
if
args
.
fp16_vae
:
return
torch
.
float16
elif
args
.
bf16_vae
:
return
torch
.
bfloat16
elif
args
.
fp32_vae
:
return
torch
.
float32
for
d
in
allowed_dtypes
:
if
d
==
torch
.
float16
and
should_use_fp16
(
device
,
prioritize_performance
=
False
):
return
d
if
d
in
VAE_DTYPES
:
return
d
return
VAE_DTYPES
[
0
]
def
get_autocast_device
(
dev
):
if
hasattr
(
dev
,
'type'
):
...
...
comfy/sd.py
View file @
6425252c
...
...
@@ -178,6 +178,7 @@ class VAE:
self
.
output_channels
=
3
self
.
process_input
=
lambda
image
:
image
*
2.0
-
1.0
self
.
process_output
=
lambda
image
:
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
self
.
working_dtypes
=
[
torch
.
bfloat16
,
torch
.
float32
]
if
config
is
None
:
if
"decoder.mid.block_1.mix_factor"
in
sd
:
...
...
@@ -245,6 +246,7 @@ class VAE:
self
.
downscale_ratio
=
2048
self
.
process_output
=
lambda
audio
:
audio
self
.
process_input
=
lambda
audio
:
audio
self
.
working_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
else
:
logging
.
warning
(
"WARNING: No VAE weights detected, VAE not initalized."
)
self
.
first_stage_model
=
None
...
...
@@ -265,12 +267,13 @@ class VAE:
self
.
device
=
device
offload_device
=
model_management
.
vae_offload_device
()
if
dtype
is
None
:
dtype
=
model_management
.
vae_dtype
()
dtype
=
model_management
.
vae_dtype
(
self
.
device
,
self
.
working_dtypes
)
self
.
vae_dtype
=
dtype
self
.
first_stage_model
.
to
(
self
.
vae_dtype
)
self
.
output_device
=
model_management
.
intermediate_device
()
self
.
patcher
=
comfy
.
model_patcher
.
ModelPatcher
(
self
.
first_stage_model
,
load_device
=
self
.
device
,
offload_device
=
offload_device
)
logging
.
debug
(
"VAE load device: {}, offload device: {}, dtype: {}"
.
format
(
self
.
device
,
offload_device
,
self
.
vae_dtype
))
def
vae_encode_crop_pixels
(
self
,
pixels
):
dims
=
pixels
.
shape
[
1
:
-
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment