chenpangpang/ComfyUI · Commit 61ec3c9d
Authored Mar 31, 2023 by comfyanonymous
Parent commit: 04b42bad

    Add a way to pass options to the transformers blocks.

Showing 5 changed files, with 33 additions and 29 deletions:
comfy/ldm/models/diffusion/ddim.py                 +7 -7
comfy/ldm/models/diffusion/ddpm.py                 +9 -9
comfy/ldm/modules/attention.py                     +5 -5
comfy/ldm/modules/diffusionmodules/openaimodel.py  +7 -6
comfy/samplers.py                                  +5 -2
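Taken together, the five diffs thread one dictionary from the sampler's model_options down through DiffusionWrapper and UNetModel into every BasicTransformerBlock. The commit itself only writes original_shape into it; the value of the change is the plumbing, which lets later features attach per-call options to the transformer blocks without another round of signature changes. A minimal toy sketch of the pattern, with every name an illustrative stand-in rather than code from the diff:

```python
# Toy model of the plumbing this commit adds: a per-call options dict is
# created at the top, annotated on the way down, and visible to every block.
def unet_forward(x_shape, blocks, transformer_options={}):
    # The real UNetModel.forward writes into the caller's dict directly;
    # copying here just keeps the toy side-effect free.
    transformer_options = dict(transformer_options)
    transformer_options["original_shape"] = list(x_shape)
    for block in blocks:
        block(transformer_options)

def make_block(name):
    def block(transformer_options):
        print(name, "sees", transformer_options)
    return block

unet_forward((1, 4, 64, 64),
             [make_block("block0"), make_block("block1")],
             {"example_flag": True})
```

The per-file diffs below show the actual plumbing.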
comfy/ldm/models/diffusion/ddim.py (+7 -7)

```diff
@@ -78,7 +78,7 @@ class DDIMSampler(object):
                dynamic_threshold=None,
                ucg_schedule=None,
-               denoise_function=None, cond_concat=None, to_zero=True, end_step=None,
+               denoise_function=None, extra_args=None, to_zero=True, end_step=None,
                **kwargs
@@ -101,7 +101,7 @@ class DDIMSampler(object):
                                                     dynamic_threshold=dynamic_threshold,
                                                     ucg_schedule=ucg_schedule,
-                                                    denoise_function=denoise_function, cond_concat=cond_concat, to_zero=to_zero, end_step=end_step)
+                                                    denoise_function=denoise_function, extra_args=extra_args, to_zero=to_zero, end_step=end_step)
@@ -174,7 +174,7 @@ class DDIMSampler(object):
                                                    dynamic_threshold=dynamic_threshold,
                                                    ucg_schedule=ucg_schedule,
-                                                   denoise_function=None, cond_concat=None)
+                                                   denoise_function=None, extra_args=None)
        return samples, intermediates
@@ -185,7 +185,7 @@ class DDIMSampler(object):
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
-                      ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None):
+                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
@@ -225,7 +225,7 @@ class DDIMSampler(object):
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
-                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat)
+                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)
@@ -249,11 +249,11 @@ class DDIMSampler(object):
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
-                      dynamic_threshold=None, denoise_function=None, cond_concat=None):
+                      dynamic_threshold=None, denoise_function=None, extra_args=None):
        b, *_, device = *x.shape, x.device

        if denoise_function is not None:
-            model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat)
+            model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
        elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
```
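Note the shape of the ddim.py change: instead of p_sample_ddim hard-coding which positional arguments denoise_function receives (uncond, cond, scale, cond_concat), the caller now packs them into an extra_args dict that is forwarded blindly with **. A minimal sketch of the convention, with illustrative stand-in names:

```python
# Old contract: p_sample_ddim chose the arguments.
#   model_output = denoise_function(apply_model, x, t,
#                                   unconditional_conditioning, c,
#                                   unconditional_guidance_scale, cond_concat)
# New contract: the caller chooses them once, up front.
def denoise_function(model_function, x, t, uncond=None, cond=None,
                     cond_scale=1.0, cond_concat=None, model_options={}):
    # illustrative stand-in for comfy.samplers.sampling_function
    return model_function(x, t)

extra_args = {"uncond": None, "cond": None, "cond_scale": 7.5,
              "model_options": {"transformer_options": {}}}
model_output = denoise_function(lambda x, t: x, x=0.0, t=0, **extra_args)
```

This is what lets samplers.py grow a model_options parameter below without ddim.py having to change again.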
comfy/ldm/models/diffusion/ddpm.py (+9 -9)

```diff
@@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

-    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None, transformer_options={}):
        if self.conditioning_key is None:
-            out = self.diffusion_model(x, t, control=control)
+            out = self.diffusion_model(x, t, control=control, transformer_options=transformer_options)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
-            out = self.diffusion_model(xc, t, control=control)
+            out = self.diffusion_model(xc, t, control=control, transformer_options=transformer_options)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
@@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
                # TorchScript changes names of the arguments
                # with argument cc defined as context=cc scripted model will produce
                # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
-                out = self.scripted_diffusion_model(x, t, cc, control=control)
+                out = self.scripted_diffusion_model(x, t, cc, control=control, transformer_options=transformer_options)
            else:
-                out = self.diffusion_model(x, t, context=cc, control=control)
+                out = self.diffusion_model(x, t, context=cc, control=control, transformer_options=transformer_options)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(xc, t, context=cc, control=control)
+            out = self.diffusion_model(xc, t, context=cc, control=control, transformer_options=transformer_options)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
+            out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control, transformer_options=transformer_options)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
-            out = self.diffusion_model(x, t, y=cc, control=control)
+            out = self.diffusion_model(x, t, y=cc, control=control, transformer_options=transformer_options)
        else:
            raise NotImplementedError()
```
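The ddpm.py half of the change is purely mechanical: whichever conditioning_key branch fires, transformer_options rides along into diffusion_model unchanged. A condensed stand-in for the wrapper, showing just two of the branches (names and structure are illustrative, not the full class):

```python
import torch

class DiffusionWrapperSketch(torch.nn.Module):
    """Condensed stand-in for DiffusionWrapper: every branch forwards
    transformer_options to the inner model untouched."""
    def __init__(self, diffusion_model, conditioning_key):
        super().__init__()
        self.diffusion_model = diffusion_model
        self.conditioning_key = conditioning_key

    def forward(self, x, t, c_concat=None, c_crossattn=None,
                control=None, transformer_options={}):
        if self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            return self.diffusion_model(xc, t, control=control,
                                        transformer_options=transformer_options)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            return self.diffusion_model(x, t, context=cc, control=control,
                                        transformer_options=transformer_options)
        raise NotImplementedError()
```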
comfy/ldm/modules/attention.py (+5 -5)

```diff
@@ -504,10 +504,10 @@ class BasicTransformerBlock(nn.Module):
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

-    def forward(self, x, context=None):
-        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+    def forward(self, x, context=None, transformer_options={}):
+        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

-    def _forward(self, x, context=None):
+    def _forward(self, x, context=None, transformer_options={}):
        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
@@ -557,7 +557,7 @@ class SpatialTransformer(nn.Module):
        self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
        self.use_linear = use_linear

-    def forward(self, x, context=None):
+    def forward(self, x, context=None, transformer_options={}):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
@@ -570,7 +570,7 @@ class SpatialTransformer(nn.Module):
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
-            x = block(x, context=context[i])
+            x = block(x, context=context[i], transformer_options=transformer_options)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
```
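The subtle part of attention.py is the checkpoint call: the ldm checkpoint helper replays _forward(*args) during the backward pass, so transformer_options must travel inside the saved args tuple rather than be captured from the enclosing scope. A toy equivalent using torch.utils.checkpoint (illustrative only; the repo uses its own checkpoint helper with a different signature):

```python
import torch
from torch.utils.checkpoint import checkpoint

def _forward(x, context, transformer_options):
    # Non-tensor args (the options dict) pass through checkpointing
    # untouched; only tensors are saved and recomputed.
    scale = transformer_options.get("scale", 1.0)
    return x * scale

x = torch.randn(2, 8, requires_grad=True)
out = checkpoint(_forward, x, None, {"scale": 2.0}, use_reentrant=False)
out.sum().backward()  # _forward is replayed with the same three args
```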
comfy/ldm/modules/diffusionmodules/openaimodel.py (+7 -6)

```diff
@@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    support it as an extra input.
    """

-    def forward(self, x, emb, context=None):
+    def forward(self, x, emb, context=None, transformer_options={}):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
-                x = layer(x, context)
+                x = layer(x, context, transformer_options)
            else:
                x = layer(x)
        return x
@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
@@ -762,6 +762,7 @@ class UNetModel(nn.Module):
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
+        transformer_options["original_shape"] = list(x.shape)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
@@ -775,13 +776,13 @@ class UNetModel(nn.Module):
        h = x.type(self.dtype)
        for id, module in enumerate(self.input_blocks):
-            h = module(h, emb, context)
+            h = module(h, emb, context, transformer_options)
            if control is not None and 'input' in control and len(control['input']) > 0:
                ctrl = control['input'].pop()
                if ctrl is not None:
                    h += ctrl
            hs.append(h)
-        h = self.middle_block(h, emb, context)
+        h = self.middle_block(h, emb, context, transformer_options)
        if control is not None and 'middle' in control and len(control['middle']) > 0:
            h += control['middle'].pop()
@@ -793,7 +794,7 @@ class UNetModel(nn.Module):
                    hsp += ctrl
                h = th.cat([h, hsp], dim=1)
                del hsp
-            h = module(h, emb, context)
+            h = module(h, emb, context, transformer_options)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
```
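UNetModel.forward is the one writer in this commit: it stamps original_shape into the dict before the latent is flattened, so any downstream block can recover the spatial layout. No reader exists yet in this diff; a hypothetical consumer inside a transformer block could look like:

```python
# Hypothetical reader of transformer_options inside a transformer block
# (not part of this commit). By the time _forward runs, x has shape
# (batch, height*width, channels), so original_shape is the only record
# of the spatial dimensions.
def _forward(x, context=None, transformer_options={}):
    original_shape = transformer_options.get("original_shape")  # e.g. [N, C, H, W]
    if original_shape is not None:
        n, c, h, w = original_shape
        # ...a block could, say, reshape attention maps back to (n, h, w, ...)
    return x
```

One Python caveat worth noting: transformer_options={} is a mutable default shared across calls that omit the argument, and since forward writes original_shape into it, all such calls share and overwrite the same dict.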
comfy/samplers.py (+5 -2)

```diff
@@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module):
#The main sampling function shared by all the samplers
#Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
    def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
        area = (x_in.shape[2], x_in.shape[3], 0, 0)
        strength = 1.0
@@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
            if control is not None:
                c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))

+            if 'transformer_options' in model_options:
+                c['transformer_options'] = model_options['transformer_options']
+
            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
            del input_x
@@ -467,7 +470,7 @@ class KSampler:
-                                                    x_T=z_enc, x0=latent_image, denoise_function=sampling_function, cond_concat=cond_concat, mask=noise_mask, to_zero=sigmas[-1]==0, end_step=sigmas.shape[0] - 1)
+                                                    x_T=z_enc, x0=latent_image, denoise_function=sampling_function, extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, end_step=sigmas.shape[0] - 1)
```
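samplers.py is the entry point of the chain: sampling_function grows a model_options dict, and a 'transformer_options' entry in it is copied into the conditioning dict c, which DiffusionWrapper.forward then receives as a keyword argument. A runnable toy replica of just that branch (build_cond is an illustrative name, not a function in the diff):

```python
# Toy replica of the new branch in sampling_function: the sub-dict is
# attached to the conditioning only when the caller supplied one.
def build_cond(c_crossattn, model_options={}):
    c = {'c_crossattn': c_crossattn}
    if 'transformer_options' in model_options:
        c['transformer_options'] = model_options['transformer_options']
    return c

print(build_cond(["emb"]))
# -> {'c_crossattn': ['emb']}
print(build_cond(["emb"], {"transformer_options": {"example_flag": True}}))
# -> {'c_crossattn': ['emb'], 'transformer_options': {'example_flag': True}}
```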