chenpangpang / ComfyUI

Commit ddc6f12a, authored Jul 05, 2023 by comfyanonymous

    Disable autocast in unet for increased speed.

Parent: 603f02d6

Showing 9 changed files with 87 additions and 82 deletions (+87 −82)
comfy/gligen.py                                      +6  −3
comfy/ldm/modules/attention.py                       +2  −2
comfy/ldm/modules/diffusionmodules/openaimodel.py    +5  −5
comfy/ldm/modules/sub_quadratic_attention.py         +2  −2
comfy/model_base.py                                  +7  −1
comfy/model_management.py                            +1  −0
comfy/sample.py                                      +3  −3
comfy/samplers.py                                    +59 −65
comfy/sd.py                                          +2  −1
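Taken together, these changes drop the torch.autocast wrapper around sampling and instead cast tensors to the UNet's dtype by hand at well-defined boundaries. Autocast re-derives the working dtype at every operator dispatch, which costs time when the model is already stored in fp16; an explicit cast on the way in and a .float() on the way out does the same job once. A minimal sketch of the before/after idea, with a stand-in nn.Linear rather than ComfyUI's UNet (bfloat16 stands in on CPU so the snippet runs anywhere):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.bfloat16

    model = nn.Linear(16, 16).to(device=device, dtype=dtype)
    x = torch.randn(4, 16, device=device)  # fp32 input, as latents usually are

    # Before this commit: autocast re-decides dtypes at every op dispatch.
    with torch.autocast(device_type=device, dtype=dtype):
        out_autocast = model(x)

    # After this commit: cast once on the way in, run natively in the model's
    # dtype, upcast once on the way out (compare comfy/model_base.py below).
    out_manual = model(x.to(model.weight.dtype)).float()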
comfy/gligen.py

@@ -215,10 +215,12 @@ class PositionNet(nn.Module):
     def forward(self, boxes, masks, positive_embeddings):
         B, N, _ = boxes.shape
-        masks = masks.unsqueeze(-1)
+        dtype = self.linears[0].weight.dtype
+        masks = masks.unsqueeze(-1).to(dtype)
+        positive_embeddings = positive_embeddings.to(dtype)

         # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
+        xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C

         # learnable null embedding
         positive_null = self.null_positive_feature.view(1, 1, -1)

@@ -252,7 +254,8 @@ class Gligen(nn.Module):
         if self.lowvram == True:
             self.position_net.cpu()
-            def func_lowvram(key, x):
+            def func_lowvram(x, extra_options):
+                key = extra_options["transformer_index"]
                 module = self.module_list[key]
                 module.to(x.device)
                 r = module(x, objs)
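With the position network running outside any autocast scope, gligen.py now casts its inputs to whatever dtype the network's weights were loaded in, read directly off the first linear layer. A minimal sketch of that pattern, using a hypothetical TinyPositionNet rather than ComfyUI's actual PositionNet:

    import torch
    import torch.nn as nn

    class TinyPositionNet(nn.Module):
        """Stand-in for gligen.PositionNet, just to show the casting pattern."""
        def __init__(self, dtype=torch.float32):
            super().__init__()
            self.linears = nn.Sequential(nn.Linear(4, 8, dtype=dtype), nn.SiLU(),
                                         nn.Linear(8, 8, dtype=dtype))

        def forward(self, boxes):
            # Read the dtype off the first layer's weights and cast the inputs,
            # the same trick gligen.py now uses for boxes/masks/embeddings.
            dtype = self.linears[0].weight.dtype
            return self.linears(boxes.to(dtype))

    boxes = torch.rand(2, 3, 4)     # (B, N, xyxy), arrives as fp32
    out = TinyPositionNet()(boxes)  # works regardless of the net's dtype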
comfy/ldm/modules/attention.py

@@ -278,7 +278,7 @@ class CrossAttentionDoggettx(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
         del q_in, k_in, v_in

-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

         mem_free_total = model_management.get_free_memory(q.device)

@@ -314,7 +314,7 @@ class CrossAttentionDoggettx(nn.Module):
                 s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
                 first_op_done = True

-            s2 = s1.softmax(dim=-1)
+            s2 = s1.softmax(dim=-1).to(v.dtype)
             del s1

             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
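Both attention edits exist because, without autocast, mixed-dtype operands are hard errors: the softmaxed scores must be cast back to v.dtype before the second einsum, and the output buffer needs an explicit dtype (a bare torch.zeros defaults to fp32). A minimal sketch of the failure mode and the fix, in bfloat16 so it runs on CPU:

    import torch

    q = torch.randn(2, 8, 16, dtype=torch.bfloat16)
    k = torch.randn(2, 8, 16, dtype=torch.bfloat16)
    v = torch.randn(2, 8, 16, dtype=torch.bfloat16)

    # Scores may be computed in a wider dtype for stability...
    s1 = torch.einsum('b i d, b j d -> b i j', q.float(), k.float())
    # ...but must be cast back before touching v, or the einsum below raises.
    s2 = s1.softmax(dim=-1).to(v.dtype)

    # Output buffer with an explicit dtype, mirroring the r1 change above.
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype)
    r1[:, :] = torch.einsum('b i j, b j d -> b i d', s2, v)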
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            normalization(channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
         )

@@ -244,7 +244,7 @@ class ResBlock(TimestepBlock):
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(

@@ -778,13 +778,13 @@ class UNetModel(nn.Module):
             self._feature_size += ch

         self.out = nn.Sequential(
-            normalization(ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype),
             nn.SiLU(),
             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-                normalization(ch),
+                nn.GroupNorm(32, ch, dtype=self.dtype),
                 conv_nd(dims, model_channels, n_embed, 1),
                 #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
             )

@@ -821,7 +821,7 @@ class UNetModel(nn.Module):
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
         hs = []
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
         emb = self.time_embed(t_emb)

         if self.num_classes is not None:
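The normalization() helper being replaced here is ldm's GroupNorm32 variant, which upcasts to fp32 internally; swapping in a plain nn.GroupNorm built with an explicit dtype keeps the whole block in the model's dtype, and the timestep embedding is cast the same way. A minimal sketch of constructing a block in a target dtype up front (fp16 on GPU, fp32 on CPU so it runs anywhere):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Every layer is created directly in the target dtype, so no runtime
    # autocast (and no GroupNorm32-style internal upcasting) is needed.
    block = nn.Sequential(
        nn.GroupNorm(32, 64, dtype=dtype),
        nn.SiLU(),
        nn.Conv2d(64, 64, 3, padding=1, dtype=dtype),
    ).to(device)

    x = torch.randn(1, 64, 8, 8, device=device)
    out = block(x.to(dtype))  # inputs cast to match, as with t_emb above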
comfy/ldm/modules/sub_quadratic_attention.py

@@ -84,7 +84,7 @@ def _summarize_chunk(
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     torch.exp(attn_weights - max_score, out=attn_weights)
-    exp_weights = attn_weights
+    exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
     max_score = max_score.squeeze(-1)
     return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)

@@ -166,7 +166,7 @@ def _get_attention_scores_no_kv_chunking(
         attn_scores /= summed
         attn_probs = attn_scores

-    hidden_states_slice = torch.bmm(attn_probs, value)
+    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
     return hidden_states_slice

 class ScannedChunk(NamedTuple):
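The same discipline applies in the sub-quadratic attention path: the softmax bookkeeping stays in its accumulation dtype, and the weights are cast down only at the torch.bmm boundary where they meet value. A minimal sketch of the pattern (bfloat16 so it runs on CPU with a recent PyTorch):

    import torch

    attn_weights = torch.randn(2, 4, 6)                 # fp32 accumulator
    value = torch.randn(2, 6, 8, dtype=torch.bfloat16)

    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    torch.exp(attn_weights - max_score, out=attn_weights)
    exp_weights = attn_weights.to(value.dtype)  # the one-line fix from the hunk
    exp_values = torch.bmm(exp_weights, value)  # operand dtypes now agree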
comfy/model_base.py

@@ -52,7 +52,13 @@ class BaseModel(torch.nn.Module):
         else:
             xc = x
         context = torch.cat(c_crossattn, 1)
-        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
+        dtype = self.get_dtype()
+        xc = xc.to(dtype)
+        t = t.to(dtype)
+        context = context.to(dtype)
+        if c_adm is not None:
+            c_adm = c_adm.to(dtype)
+        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()

     def get_dtype(self):
         return self.diffusion_model.dtype
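This is the heart of the commit: apply_model becomes the single place where dtypes change hands. Inputs are cast to the diffusion model's dtype on entry, and the result is handed back as fp32, so nothing downstream ever sees fp16. A minimal sketch of that contract with a stand-in module (the .weight.dtype read here substitutes for ComfyUI's .dtype property):

    import torch
    import torch.nn as nn

    class WrappedModel(nn.Module):
        """Minimal sketch of BaseModel.apply_model's new casting contract."""
        def __init__(self, diffusion_model):
            super().__init__()
            self.diffusion_model = diffusion_model

        def get_dtype(self):
            return self.diffusion_model.weight.dtype

        def forward(self, xc, t):
            dtype = self.get_dtype()
            # Cast every input tensor to the model dtype on the way in...
            xc = xc.to(dtype)
            t = t.to(dtype)
            # ...and hand callers fp32 on the way out, so the rest of the
            # pipeline never sees reduced-precision tensors.
            return self.diffusion_model(xc + t).float()

    m = WrappedModel(nn.Linear(8, 8))
    out = m(torch.randn(2, 8), torch.randn(2, 8))
    assert out.dtype == torch.float32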
comfy/model_management.py

@@ -264,6 +264,7 @@ def load_model_gpu(model):
     torch_dev = model.load_device
     model.model_patches_to(torch_dev)
+    model.model_patches_to(model.model_dtype())

     if is_device_cpu(torch_dev):
         vram_set_state = VRAMState.DISABLED
comfy/sample.py

@@ -51,11 +51,11 @@ def get_models_from_cond(cond, model_type):
             models += [c[1][model_type]]
     return models

-def load_additional_models(positive, negative):
+def load_additional_models(positive, negative, dtype):
     """loads additional models in positive and negative conditioning"""
     control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
     gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
-    gligen = [x[1] for x in gligen]
+    gligen = [x[1].to(dtype) for x in gligen]
     models = control_nets + gligen
     comfy.model_management.load_controlnet_gpu(models)
     return models

@@ -81,7 +81,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     positive_copy = broadcast_cond(positive, noise.shape[0], device)
     negative_copy = broadcast_cond(negative, noise.shape[0], device)

-    models = load_additional_models(positive, negative)
+    models = load_additional_models(positive, negative, model.model_dtype())

     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
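sample.py now threads the UNet's dtype down to the GLIGEN modules loaded alongside it, so they match the model they will be patched into. A sketch of the same pattern with stand-in names (load_additional_models here is a simplified hypothetical, not the real signature):

    import torch
    import torch.nn as nn

    def load_additional_models(cond_models, dtype):
        # Anything loaded alongside the UNet is moved to its dtype up front.
        return [m.to(dtype) for m in cond_models]

    unet_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    extra = load_additional_models([nn.Linear(4, 4)], unet_dtype)
    assert extra[0].weight.dtype == unet_dtype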
comfy/samplers.py

@@ -2,7 +2,6 @@ from .k_diffusion import sampling as k_diffusion_sampling
 from .k_diffusion import external as k_diffusion_external
 from .extra_samplers import uni_pc
 import torch
-import contextlib
 from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps

@@ -577,11 +576,6 @@ class KSampler:
             apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
             apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

-        if self.model.get_dtype() == torch.float16:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = contextlib.nullcontext
-
         if self.model.is_adm():
             positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
             negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")

@@ -612,67 +606,67 @@ KSampler:
         else:
             max_denoise = True

-        with precision_scope(model_management.get_autocast_device(self.device)):
-            if self.sampler == "uni_pc":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
-            elif self.sampler == "uni_pc_bh2":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
-            elif self.sampler == "ddim":
-                timesteps = []
-                for s in range(sigmas.shape[0]):
-                    timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
-                noise_mask = None
-                if denoise_mask is not None:
-                    noise_mask = 1.0 - denoise_mask
-
-                ddim_callback = None
-                if callback is not None:
-                    total_steps = len(timesteps) - 1
-                    ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
-
-                sampler = DDIMSampler(self.model, device=self.device)
-                sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
-                z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
-                samples, _ = sampler.sample_custom(ddim_timesteps=timesteps, conditioning=positive, batch_size=noise.shape[0], shape=noise.shape[1:], verbose=False, unconditional_guidance_scale=cfg, unconditional_conditioning=negative, eta=0.0, x_T=z_enc, x0=latent_image, img_callback=ddim_callback, denoise_function=sampling_function, extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, end_step=sigmas.shape[0] - 1, disable_pbar=disable_pbar)
-            else:
-                extra_args["denoise_mask"] = denoise_mask
-                self.model_k.latent_image = latent_image
-                self.model_k.noise = noise
-
-                if max_denoise:
-                    noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
-                else:
-                    noise = noise * sigmas[0]
-
-                k_callback = None
-                total_steps = len(sigmas) - 1
-                if callback is not None:
-                    k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
-
-                if latent_image is not None:
-                    noise += latent_image
-                if self.sampler == "dpm_fast":
-                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-                elif self.sampler == "dpm_adaptive":
-                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-                else:
-                    samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+        if self.sampler == "uni_pc":
+            samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
+        elif self.sampler == "uni_pc_bh2":
+            samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
+        elif self.sampler == "ddim":
+            timesteps = []
+            for s in range(sigmas.shape[0]):
+                timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
+            noise_mask = None
+            if denoise_mask is not None:
+                noise_mask = 1.0 - denoise_mask
+
+            ddim_callback = None
+            if callback is not None:
+                total_steps = len(timesteps) - 1
+                ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
+
+            sampler = DDIMSampler(self.model, device=self.device)
+            sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
+            z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
+            samples, _ = sampler.sample_custom(ddim_timesteps=timesteps, conditioning=positive, batch_size=noise.shape[0], shape=noise.shape[1:], verbose=False, unconditional_guidance_scale=cfg, unconditional_conditioning=negative, eta=0.0, x_T=z_enc, x0=latent_image, img_callback=ddim_callback, denoise_function=sampling_function, extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, end_step=sigmas.shape[0] - 1, disable_pbar=disable_pbar)
+        else:
+            extra_args["denoise_mask"] = denoise_mask
+            self.model_k.latent_image = latent_image
+            self.model_k.noise = noise
+
+            if max_denoise:
+                noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
+            else:
+                noise = noise * sigmas[0]
+
+            k_callback = None
+            total_steps = len(sigmas) - 1
+            if callback is not None:
+                k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
+
+            if latent_image is not None:
+                noise += latent_image
+            if self.sampler == "dpm_fast":
+                samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            elif self.sampler == "dpm_adaptive":
+                samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
+            else:
+                samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)

         return self.model.process_latent_out(samples.to(torch.float32))
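With the casting handled in model_base.py, the sampler no longer needs a precision scope at all: the contextlib import, the precision_scope selection, and the with wrapper are deleted, and the dispatch block is simply dedented one level, which is why this file accounts for most of the line churn. A before/after sketch of the control flow, with the dispatch body elided:

    import contextlib
    import torch

    def sample_old(model_dtype, device_type):
        # Old: pick a precision scope per model dtype, wrap every sampler call.
        if model_dtype == torch.float16:
            precision_scope = torch.autocast
        else:
            precision_scope = contextlib.nullcontext
        with precision_scope(device_type):
            pass  # sampler dispatch ran here, one indent level deeper

    def sample_new():
        # New: no context manager; dtype discipline lives in apply_model.
        pass  # the same dispatch, dedented

    sample_old(torch.float32, "cpu")
    sample_new()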
comfy/sd.py

@@ -291,7 +291,8 @@ class ModelPatcher:
             patch_list[k] = patch_list[k].to(device)

     def model_dtype(self):
-        return self.model.get_dtype()
+        if hasattr(self.model, "get_dtype"):
+            return self.model.get_dtype()

     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         p = {}