Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
e7bee85d
Commit
e7bee85d
authored
Jul 06, 2023
by
comfyanonymous
Browse files
Add arguments to run the VAE in fp16 or bf16 for testing.
parent
f5232c48
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
7 deletions
+21
-7
.github/workflows/windows_release_nightly_pytorch.yml
.github/workflows/windows_release_nightly_pytorch.yml
+1
-1
comfy/cli_args.py
comfy/cli_args.py
+4
-0
comfy/model_management.py
comfy/model_management.py
+8
-0
comfy/sd.py
comfy/sd.py
+8
-6
No files found.
.github/workflows/windows_release_nightly_pytorch.yml
View file @
e7bee85d
...
@@ -54,7 +54,7 @@ jobs:
...
@@ -54,7 +54,7 @@ jobs:
cd ..
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on
-mf=BCJ2
ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
cd ComfyUI_windows_portable_nightly_pytorch
cd ComfyUI_windows_portable_nightly_pytorch
...
...
comfy/cli_args.py
View file @
e7bee85d
...
@@ -46,6 +46,10 @@ fp_group = parser.add_mutually_exclusive_group()
...
@@ -46,6 +46,10 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group
.
add_argument
(
"--force-fp32"
,
action
=
"store_true"
,
help
=
"Force fp32 (If this makes your GPU work better please report it)."
)
fp_group
.
add_argument
(
"--force-fp32"
,
action
=
"store_true"
,
help
=
"Force fp32 (If this makes your GPU work better please report it)."
)
fp_group
.
add_argument
(
"--force-fp16"
,
action
=
"store_true"
,
help
=
"Force fp16."
)
fp_group
.
add_argument
(
"--force-fp16"
,
action
=
"store_true"
,
help
=
"Force fp16."
)
# VAE precision overrides. The two flags live in their own mutually
# exclusive group, so at most one of them may be given on the command line.
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument(
    "--fp16-vae",
    action="store_true",
    help="Run the VAE in fp16, might cause black images.",
)
fpvae_group.add_argument(
    "--bf16-vae",
    action="store_true",
    help="Run the VAE in bf16, might lower quality.",
)
parser
.
add_argument
(
"--directml"
,
type
=
int
,
nargs
=
"?"
,
metavar
=
"DIRECTML_DEVICE"
,
const
=-
1
,
help
=
"Use torch-directml."
)
parser
.
add_argument
(
"--directml"
,
type
=
int
,
nargs
=
"?"
,
metavar
=
"DIRECTML_DEVICE"
,
const
=-
1
,
help
=
"Use torch-directml."
)
class
LatentPreviewMethod
(
enum
.
Enum
):
class
LatentPreviewMethod
(
enum
.
Enum
):
...
...
comfy/model_management.py
View file @
e7bee85d
...
@@ -366,6 +366,14 @@ def vae_offload_device():
...
@@ -366,6 +366,14 @@ def vae_offload_device():
else
:
else
:
return
torch
.
device
(
"cpu"
)
return
torch
.
device
(
"cpu"
)
def vae_dtype():
    """Return the torch dtype selected for the VAE by the parsed arguments.

    ``args.fp16_vae`` is checked first and yields ``torch.float16``;
    otherwise ``args.bf16_vae`` yields ``torch.bfloat16``. When neither
    attribute is truthy the result is ``torch.float32``.
    """
    if args.fp16_vae:
        return torch.float16
    if args.bf16_vae:
        return torch.bfloat16
    # Neither override requested: fall back to full precision.
    return torch.float32
def
get_autocast_device
(
dev
):
def
get_autocast_device
(
dev
):
if
hasattr
(
dev
,
'type'
):
if
hasattr
(
dev
,
'type'
):
return
dev
.
type
return
dev
.
type
...
...
comfy/sd.py
View file @
e7bee85d
...
@@ -505,6 +505,8 @@ class VAE:
...
@@ -505,6 +505,8 @@ class VAE:
device
=
model_management
.
vae_device
()
device
=
model_management
.
vae_device
()
self
.
device
=
device
self
.
device
=
device
self
.
offload_device
=
model_management
.
vae_offload_device
()
self
.
offload_device
=
model_management
.
vae_offload_device
()
self
.
vae_dtype
=
model_management
.
vae_dtype
()
self
.
first_stage_model
.
to
(
self
.
vae_dtype
)
def
decode_tiled_
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
def
decode_tiled_
(
self
,
samples
,
tile_x
=
64
,
tile_y
=
64
,
overlap
=
16
):
steps
=
samples
.
shape
[
0
]
*
utils
.
get_tiled_scale_steps
(
samples
.
shape
[
3
],
samples
.
shape
[
2
],
tile_x
,
tile_y
,
overlap
)
steps
=
samples
.
shape
[
0
]
*
utils
.
get_tiled_scale_steps
(
samples
.
shape
[
3
],
samples
.
shape
[
2
],
tile_x
,
tile_y
,
overlap
)
...
@@ -512,7 +514,7 @@ class VAE:
...
@@ -512,7 +514,7 @@ class VAE:
steps
+=
samples
.
shape
[
0
]
*
utils
.
get_tiled_scale_steps
(
samples
.
shape
[
3
],
samples
.
shape
[
2
],
tile_x
*
2
,
tile_y
//
2
,
overlap
)
steps
+=
samples
.
shape
[
0
]
*
utils
.
get_tiled_scale_steps
(
samples
.
shape
[
3
],
samples
.
shape
[
2
],
tile_x
*
2
,
tile_y
//
2
,
overlap
)
pbar
=
utils
.
ProgressBar
(
steps
)
pbar
=
utils
.
ProgressBar
(
steps
)
decode_fn
=
lambda
a
:
(
self
.
first_stage_model
.
decode
(
a
.
to
(
self
.
device
))
+
1.0
)
decode_fn
=
lambda
a
:
(
self
.
first_stage_model
.
decode
(
a
.
to
(
self
.
vae_dtype
).
to
(
self
.
device
))
+
1.0
)
.
float
()
output
=
torch
.
clamp
((
output
=
torch
.
clamp
((
(
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
//
2
,
tile_y
*
2
,
overlap
,
upscale_amount
=
8
,
pbar
=
pbar
)
+
(
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
//
2
,
tile_y
*
2
,
overlap
,
upscale_amount
=
8
,
pbar
=
pbar
)
+
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
*
2
,
tile_y
//
2
,
overlap
,
upscale_amount
=
8
,
pbar
=
pbar
)
+
utils
.
tiled_scale
(
samples
,
decode_fn
,
tile_x
*
2
,
tile_y
//
2
,
overlap
,
upscale_amount
=
8
,
pbar
=
pbar
)
+
...
@@ -526,7 +528,7 @@ class VAE:
...
@@ -526,7 +528,7 @@ class VAE:
steps
+=
pixel_samples
.
shape
[
0
]
*
utils
.
get_tiled_scale_steps
(
pixel_samples
.
shape
[
3
],
pixel_samples
.
shape
[
2
],
tile_x
*
2
,
tile_y
//
2
,
overlap
)
steps
+=
pixel_samples
.
shape
[
0
]
*
utils
.
get_tiled_scale_steps
(
pixel_samples
.
shape
[
3
],
pixel_samples
.
shape
[
2
],
tile_x
*
2
,
tile_y
//
2
,
overlap
)
pbar
=
utils
.
ProgressBar
(
steps
)
pbar
=
utils
.
ProgressBar
(
steps
)
encode_fn
=
lambda
a
:
self
.
first_stage_model
.
encode
(
2.
*
a
.
to
(
self
.
device
)
-
1.
).
sample
()
encode_fn
=
lambda
a
:
self
.
first_stage_model
.
encode
(
2.
*
a
.
to
(
self
.
vae_dtype
).
to
(
self
.
device
)
-
1.
).
sample
()
.
float
()
samples
=
utils
.
tiled_scale
(
pixel_samples
,
encode_fn
,
tile_x
,
tile_y
,
overlap
,
upscale_amount
=
(
1
/
8
),
out_channels
=
4
,
pbar
=
pbar
)
samples
=
utils
.
tiled_scale
(
pixel_samples
,
encode_fn
,
tile_x
,
tile_y
,
overlap
,
upscale_amount
=
(
1
/
8
),
out_channels
=
4
,
pbar
=
pbar
)
samples
+=
utils
.
tiled_scale
(
pixel_samples
,
encode_fn
,
tile_x
*
2
,
tile_y
//
2
,
overlap
,
upscale_amount
=
(
1
/
8
),
out_channels
=
4
,
pbar
=
pbar
)
samples
+=
utils
.
tiled_scale
(
pixel_samples
,
encode_fn
,
tile_x
*
2
,
tile_y
//
2
,
overlap
,
upscale_amount
=
(
1
/
8
),
out_channels
=
4
,
pbar
=
pbar
)
samples
+=
utils
.
tiled_scale
(
pixel_samples
,
encode_fn
,
tile_x
//
2
,
tile_y
*
2
,
overlap
,
upscale_amount
=
(
1
/
8
),
out_channels
=
4
,
pbar
=
pbar
)
samples
+=
utils
.
tiled_scale
(
pixel_samples
,
encode_fn
,
tile_x
//
2
,
tile_y
*
2
,
overlap
,
upscale_amount
=
(
1
/
8
),
out_channels
=
4
,
pbar
=
pbar
)
...
@@ -543,8 +545,8 @@ class VAE:
...
@@ -543,8 +545,8 @@ class VAE:
pixel_samples
=
torch
.
empty
((
samples_in
.
shape
[
0
],
3
,
round
(
samples_in
.
shape
[
2
]
*
8
),
round
(
samples_in
.
shape
[
3
]
*
8
)),
device
=
"cpu"
)
pixel_samples
=
torch
.
empty
((
samples_in
.
shape
[
0
],
3
,
round
(
samples_in
.
shape
[
2
]
*
8
),
round
(
samples_in
.
shape
[
3
]
*
8
)),
device
=
"cpu"
)
for
x
in
range
(
0
,
samples_in
.
shape
[
0
],
batch_number
):
for
x
in
range
(
0
,
samples_in
.
shape
[
0
],
batch_number
):
samples
=
samples_in
[
x
:
x
+
batch_number
].
to
(
self
.
device
)
samples
=
samples_in
[
x
:
x
+
batch_number
].
to
(
self
.
vae_dtype
).
to
(
self
.
device
)
pixel_samples
[
x
:
x
+
batch_number
]
=
torch
.
clamp
((
self
.
first_stage_model
.
decode
(
samples
)
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
).
cpu
()
pixel_samples
[
x
:
x
+
batch_number
]
=
torch
.
clamp
((
self
.
first_stage_model
.
decode
(
samples
)
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
).
cpu
()
.
float
()
except
model_management
.
OOM_EXCEPTION
as
e
:
except
model_management
.
OOM_EXCEPTION
as
e
:
print
(
"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding."
)
print
(
"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding."
)
pixel_samples
=
self
.
decode_tiled_
(
samples_in
)
pixel_samples
=
self
.
decode_tiled_
(
samples_in
)
...
@@ -570,8 +572,8 @@ class VAE:
...
@@ -570,8 +572,8 @@ class VAE:
batch_number
=
max
(
1
,
batch_number
)
batch_number
=
max
(
1
,
batch_number
)
samples
=
torch
.
empty
((
pixel_samples
.
shape
[
0
],
4
,
round
(
pixel_samples
.
shape
[
2
]
//
8
),
round
(
pixel_samples
.
shape
[
3
]
//
8
)),
device
=
"cpu"
)
samples
=
torch
.
empty
((
pixel_samples
.
shape
[
0
],
4
,
round
(
pixel_samples
.
shape
[
2
]
//
8
),
round
(
pixel_samples
.
shape
[
3
]
//
8
)),
device
=
"cpu"
)
for
x
in
range
(
0
,
pixel_samples
.
shape
[
0
],
batch_number
):
for
x
in
range
(
0
,
pixel_samples
.
shape
[
0
],
batch_number
):
pixels_in
=
(
2.
*
pixel_samples
[
x
:
x
+
batch_number
]
-
1.
).
to
(
self
.
device
)
pixels_in
=
(
2.
*
pixel_samples
[
x
:
x
+
batch_number
]
-
1.
).
to
(
self
.
vae_dtype
).
to
(
self
.
device
)
samples
[
x
:
x
+
batch_number
]
=
self
.
first_stage_model
.
encode
(
pixels_in
).
sample
().
cpu
()
samples
[
x
:
x
+
batch_number
]
=
self
.
first_stage_model
.
encode
(
pixels_in
).
sample
().
cpu
()
.
float
()
except
model_management
.
OOM_EXCEPTION
as
e
:
except
model_management
.
OOM_EXCEPTION
as
e
:
print
(
"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding."
)
print
(
"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment