Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
264caca2
Commit
264caca2
authored
Jun 26, 2024
by
comfyanonymous
Browse files
ControlNetApplySD3 node can now be used to use SD3 controlnets.
parent
f8f7568d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
11 deletions
+41
-11
comfy/cldm/mmdit.py
comfy/cldm/mmdit.py
+0
-5
comfy/controlnet.py
comfy/controlnet.py
+24
-4
comfy_extras/nodes_sd3.py
comfy_extras/nodes_sd3.py
+15
-0
nodes.py
nodes.py
+2
-2
No files found.
comfy/cldm/mmdit.py
View file @
264caca2
import
torch
from
typing
import
Dict
,
Optional
import
comfy.ldm.modules.diffusionmodules.mmdit
import
comfy.latent_formats
class
ControlNet
(
comfy
.
ldm
.
modules
.
diffusionmodules
.
mmdit
.
MMDiT
):
def
__init__
(
...
...
@@ -30,8 +29,6 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
operations
=
operations
)
self
.
latent_format
=
comfy
.
latent_formats
.
SD3
()
def
forward
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -42,10 +39,8 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
)
->
torch
.
Tensor
:
#weird sd3 controlnet specific stuff
hint
=
hint
*
self
.
latent_format
.
scale_factor
# self.latent_format.process_in(hint)
y
=
torch
.
zeros_like
(
y
)
if
self
.
context_processor
is
not
None
:
context
=
self
.
context_processor
(
context
)
...
...
comfy/controlnet.py
View file @
264caca2
...
...
@@ -7,6 +7,7 @@ import comfy.model_management
import
comfy.model_detection
import
comfy.model_patcher
import
comfy.ops
import
comfy.latent_formats
import
comfy.cldm.cldm
import
comfy.t2i_adapter.adapter
...
...
@@ -38,6 +39,8 @@ class ControlBase:
self
.
cond_hint
=
None
self
.
strength
=
1.0
self
.
timestep_percent_range
=
(
0.0
,
1.0
)
self
.
latent_format
=
None
self
.
vae
=
None
self
.
global_average_pooling
=
False
self
.
timestep_range
=
None
self
.
compression_ratio
=
8
...
...
@@ -48,10 +51,12 @@ class ControlBase:
self
.
device
=
device
self
.
previous_controlnet
=
None
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
0.0
,
1.0
)):
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
0.0
,
1.0
)
,
vae
=
None
):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
self
.
timestep_percent_range
=
timestep_percent_range
if
self
.
latent_format
is
not
None
:
self
.
vae
=
vae
return
self
def
pre_run
(
self
,
model
,
percent_to_timestep_function
):
...
...
@@ -84,6 +89,8 @@ class ControlBase:
c
.
global_average_pooling
=
self
.
global_average_pooling
c
.
compression_ratio
=
self
.
compression_ratio
c
.
upscale_algorithm
=
self
.
upscale_algorithm
c
.
latent_format
=
self
.
latent_format
c
.
vae
=
self
.
vae
def
inference_memory_requirements
(
self
,
dtype
):
if
self
.
previous_controlnet
is
not
None
:
...
...
@@ -129,7 +136,7 @@ class ControlBase:
return
out
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
=
None
,
global_average_pooling
=
False
,
compression_ratio
=
8
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
def
__init__
(
self
,
control_model
=
None
,
global_average_pooling
=
False
,
compression_ratio
=
8
,
latent_format
=
None
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
super
().
__init__
(
device
)
self
.
control_model
=
control_model
self
.
load_device
=
load_device
...
...
@@ -140,6 +147,7 @@ class ControlNet(ControlBase):
self
.
global_average_pooling
=
global_average_pooling
self
.
model_sampling_current
=
None
self
.
manual_cast_dtype
=
manual_cast_dtype
self
.
latent_format
=
latent_format
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
control_prev
=
None
...
...
@@ -162,7 +170,17 @@ class ControlNet(ControlBase):
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
self
.
cond_hint
=
comfy
.
utils
.
common_upscale
(
self
.
cond_hint_original
,
x_noisy
.
shape
[
3
]
*
self
.
compression_ratio
,
x_noisy
.
shape
[
2
]
*
self
.
compression_ratio
,
self
.
upscale_algorithm
,
"center"
).
to
(
dtype
).
to
(
self
.
device
)
compression_ratio
=
self
.
compression_ratio
if
self
.
vae
is
not
None
:
compression_ratio
*=
self
.
vae
.
downscale_ratio
self
.
cond_hint
=
comfy
.
utils
.
common_upscale
(
self
.
cond_hint_original
,
x_noisy
.
shape
[
3
]
*
compression_ratio
,
x_noisy
.
shape
[
2
]
*
compression_ratio
,
self
.
upscale_algorithm
,
"center"
)
if
self
.
vae
is
not
None
:
loaded_models
=
comfy
.
model_management
.
loaded_models
(
only_currently_used
=
True
)
self
.
cond_hint
=
self
.
vae
.
encode
(
self
.
cond_hint
.
movedim
(
1
,
-
1
))
comfy
.
model_management
.
load_models_gpu
(
loaded_models
)
if
self
.
latent_format
is
not
None
:
self
.
cond_hint
=
self
.
latent_format
.
process_in
(
self
.
cond_hint
)
self
.
cond_hint
=
self
.
cond_hint
.
to
(
device
=
self
.
device
,
dtype
=
dtype
)
if
x_noisy
.
shape
[
0
]
!=
self
.
cond_hint
.
shape
[
0
]:
self
.
cond_hint
=
broadcast_image_to
(
self
.
cond_hint
,
x_noisy
.
shape
[
0
],
batched_number
)
...
...
@@ -341,7 +359,9 @@ def load_controlnet_mmdit(sd):
if
len
(
unexpected
)
>
0
:
logging
.
debug
(
"unexpected controlnet keys: {}"
.
format
(
unexpected
))
control
=
ControlNet
(
control_model
,
compression_ratio
=
1
,
load_device
=
load_device
,
manual_cast_dtype
=
manual_cast_dtype
)
latent_format
=
comfy
.
latent_formats
.
SD3
()
latent_format
.
shift_factor
=
0
#SD3 controlnet weirdness
control
=
ControlNet
(
control_model
,
compression_ratio
=
1
,
latent_format
=
latent_format
,
load_device
=
load_device
,
manual_cast_dtype
=
manual_cast_dtype
)
return
control
...
...
comfy_extras/nodes_sd3.py
View file @
264caca2
...
...
@@ -80,8 +80,23 @@ class CLIPTextEncodeSD3:
return
([[
cond
,
{
"pooled_output"
:
pooled
}]],
)
class
ControlNetApplySD3
(
nodes
.
ControlNetApplyAdvanced
):
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"positive"
:
(
"CONDITIONING"
,
),
"negative"
:
(
"CONDITIONING"
,
),
"control_net"
:
(
"CONTROL_NET"
,
),
"vae"
:
(
"VAE"
,
),
"image"
:
(
"IMAGE"
,
),
"strength"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"start_percent"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"end_percent"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
})
}}
CATEGORY
=
"_for_testing/sd3"
NODE_CLASS_MAPPINGS
=
{
"TripleCLIPLoader"
:
TripleCLIPLoader
,
"EmptySD3LatentImage"
:
EmptySD3LatentImage
,
"CLIPTextEncodeSD3"
:
CLIPTextEncodeSD3
,
"ControlNetApplySD3"
:
ControlNetApplySD3
,
}
nodes.py
View file @
264caca2
...
...
@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
CATEGORY
=
"conditioning"
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start_percent
,
end_percent
):
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start_percent
,
end_percent
,
vae
=
None
):
if
strength
==
0
:
return
(
positive
,
negative
)
...
...
@@ -800,7 +800,7 @@ class ControlNetApplyAdvanced:
if
prev_cnet
in
cnets
:
c_net
=
cnets
[
prev_cnet
]
else
:
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start_percent
,
end_percent
))
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start_percent
,
end_percent
)
,
vae
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
cnets
[
prev_cnet
]
=
c_net
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment