Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
264caca2
"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "6a10640f0dd019dd7c74006909f38d0056c317bd"
Commit
264caca2
authored
Jun 26, 2024
by
comfyanonymous
Browse files
ControlNetApplySD3 node can now be used to use SD3 controlnets.
parent
f8f7568d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
11 deletions
+41
-11
comfy/cldm/mmdit.py
comfy/cldm/mmdit.py
+0
-5
comfy/controlnet.py
comfy/controlnet.py
+24
-4
comfy_extras/nodes_sd3.py
comfy_extras/nodes_sd3.py
+15
-0
nodes.py
nodes.py
+2
-2
No files found.
comfy/cldm/mmdit.py
View file @
264caca2
import
torch
from
typing
import
Dict
,
Optional
import
comfy.ldm.modules.diffusionmodules.mmdit
import
comfy.latent_formats
class
ControlNet
(
comfy
.
ldm
.
modules
.
diffusionmodules
.
mmdit
.
MMDiT
):
def
__init__
(
...
...
@@ -30,8 +29,6 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
operations
=
operations
)
self
.
latent_format
=
comfy
.
latent_formats
.
SD3
()
def
forward
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -42,10 +39,8 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
)
->
torch
.
Tensor
:
#weird sd3 controlnet specific stuff
hint
=
hint
*
self
.
latent_format
.
scale_factor
# self.latent_format.process_in(hint)
y
=
torch
.
zeros_like
(
y
)
if
self
.
context_processor
is
not
None
:
context
=
self
.
context_processor
(
context
)
...
...
comfy/controlnet.py
View file @
264caca2
...
...
@@ -7,6 +7,7 @@ import comfy.model_management
import
comfy.model_detection
import
comfy.model_patcher
import
comfy.ops
import
comfy.latent_formats
import
comfy.cldm.cldm
import
comfy.t2i_adapter.adapter
...
...
@@ -38,6 +39,8 @@ class ControlBase:
self
.
cond_hint
=
None
self
.
strength
=
1.0
self
.
timestep_percent_range
=
(
0.0
,
1.0
)
self
.
latent_format
=
None
self
.
vae
=
None
self
.
global_average_pooling
=
False
self
.
timestep_range
=
None
self
.
compression_ratio
=
8
...
...
@@ -48,10 +51,12 @@ class ControlBase:
self
.
device
=
device
self
.
previous_controlnet
=
None
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
0.0
,
1.0
)):
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
0.0
,
1.0
)
,
vae
=
None
):
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
self
.
timestep_percent_range
=
timestep_percent_range
if
self
.
latent_format
is
not
None
:
self
.
vae
=
vae
return
self
def
pre_run
(
self
,
model
,
percent_to_timestep_function
):
...
...
@@ -84,6 +89,8 @@ class ControlBase:
c
.
global_average_pooling
=
self
.
global_average_pooling
c
.
compression_ratio
=
self
.
compression_ratio
c
.
upscale_algorithm
=
self
.
upscale_algorithm
c
.
latent_format
=
self
.
latent_format
c
.
vae
=
self
.
vae
def
inference_memory_requirements
(
self
,
dtype
):
if
self
.
previous_controlnet
is
not
None
:
...
...
@@ -129,7 +136,7 @@ class ControlBase:
return
out
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
=
None
,
global_average_pooling
=
False
,
compression_ratio
=
8
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
def
__init__
(
self
,
control_model
=
None
,
global_average_pooling
=
False
,
compression_ratio
=
8
,
latent_format
=
None
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
super
().
__init__
(
device
)
self
.
control_model
=
control_model
self
.
load_device
=
load_device
...
...
@@ -140,6 +147,7 @@ class ControlNet(ControlBase):
self
.
global_average_pooling
=
global_average_pooling
self
.
model_sampling_current
=
None
self
.
manual_cast_dtype
=
manual_cast_dtype
self
.
latent_format
=
latent_format
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
control_prev
=
None
...
...
@@ -162,7 +170,17 @@ class ControlNet(ControlBase):
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
self
.
cond_hint
=
None
self
.
cond_hint
=
comfy
.
utils
.
common_upscale
(
self
.
cond_hint_original
,
x_noisy
.
shape
[
3
]
*
self
.
compression_ratio
,
x_noisy
.
shape
[
2
]
*
self
.
compression_ratio
,
self
.
upscale_algorithm
,
"center"
).
to
(
dtype
).
to
(
self
.
device
)
compression_ratio
=
self
.
compression_ratio
if
self
.
vae
is
not
None
:
compression_ratio
*=
self
.
vae
.
downscale_ratio
self
.
cond_hint
=
comfy
.
utils
.
common_upscale
(
self
.
cond_hint_original
,
x_noisy
.
shape
[
3
]
*
compression_ratio
,
x_noisy
.
shape
[
2
]
*
compression_ratio
,
self
.
upscale_algorithm
,
"center"
)
if
self
.
vae
is
not
None
:
loaded_models
=
comfy
.
model_management
.
loaded_models
(
only_currently_used
=
True
)
self
.
cond_hint
=
self
.
vae
.
encode
(
self
.
cond_hint
.
movedim
(
1
,
-
1
))
comfy
.
model_management
.
load_models_gpu
(
loaded_models
)
if
self
.
latent_format
is
not
None
:
self
.
cond_hint
=
self
.
latent_format
.
process_in
(
self
.
cond_hint
)
self
.
cond_hint
=
self
.
cond_hint
.
to
(
device
=
self
.
device
,
dtype
=
dtype
)
if
x_noisy
.
shape
[
0
]
!=
self
.
cond_hint
.
shape
[
0
]:
self
.
cond_hint
=
broadcast_image_to
(
self
.
cond_hint
,
x_noisy
.
shape
[
0
],
batched_number
)
...
...
@@ -341,7 +359,9 @@ def load_controlnet_mmdit(sd):
if
len
(
unexpected
)
>
0
:
logging
.
debug
(
"unexpected controlnet keys: {}"
.
format
(
unexpected
))
control
=
ControlNet
(
control_model
,
compression_ratio
=
1
,
load_device
=
load_device
,
manual_cast_dtype
=
manual_cast_dtype
)
latent_format
=
comfy
.
latent_formats
.
SD3
()
latent_format
.
shift_factor
=
0
#SD3 controlnet weirdness
control
=
ControlNet
(
control_model
,
compression_ratio
=
1
,
latent_format
=
latent_format
,
load_device
=
load_device
,
manual_cast_dtype
=
manual_cast_dtype
)
return
control
...
...
comfy_extras/nodes_sd3.py
View file @
264caca2
...
...
@@ -80,8 +80,23 @@ class CLIPTextEncodeSD3:
return
([[
cond
,
{
"pooled_output"
:
pooled
}]],
)
class
ControlNetApplySD3
(
nodes
.
ControlNetApplyAdvanced
):
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"positive"
:
(
"CONDITIONING"
,
),
"negative"
:
(
"CONDITIONING"
,
),
"control_net"
:
(
"CONTROL_NET"
,
),
"vae"
:
(
"VAE"
,
),
"image"
:
(
"IMAGE"
,
),
"strength"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"start_percent"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"end_percent"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
})
}}
CATEGORY
=
"_for_testing/sd3"
NODE_CLASS_MAPPINGS
=
{
"TripleCLIPLoader"
:
TripleCLIPLoader
,
"EmptySD3LatentImage"
:
EmptySD3LatentImage
,
"CLIPTextEncodeSD3"
:
CLIPTextEncodeSD3
,
"ControlNetApplySD3"
:
ControlNetApplySD3
,
}
nodes.py
View file @
264caca2
...
...
@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
CATEGORY
=
"conditioning"
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start_percent
,
end_percent
):
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start_percent
,
end_percent
,
vae
=
None
):
if
strength
==
0
:
return
(
positive
,
negative
)
...
...
@@ -800,7 +800,7 @@ class ControlNetApplyAdvanced:
if
prev_cnet
in
cnets
:
c_net
=
cnets
[
prev_cnet
]
else
:
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start_percent
,
end_percent
))
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start_percent
,
end_percent
)
,
vae
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
cnets
[
prev_cnet
]
=
c_net
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment