Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
264caca2
Commit
264caca2
authored
Jun 26, 2024
by
comfyanonymous
Browse files
ControlNetApplySD3 node can now be used to use SD3 controlnets.
parent
f8f7568d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
11 deletions
+41
-11
comfy/cldm/mmdit.py
comfy/cldm/mmdit.py
+0
-5
comfy/controlnet.py
comfy/controlnet.py
+24
-4
comfy_extras/nodes_sd3.py
comfy_extras/nodes_sd3.py
+15
-0
nodes.py
nodes.py
+2
-2
No files found.
comfy/cldm/mmdit.py
View file @
264caca2
import
torch
import
torch
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
import
comfy.ldm.modules.diffusionmodules.mmdit
import
comfy.ldm.modules.diffusionmodules.mmdit
import
comfy.latent_formats
class
ControlNet
(
comfy
.
ldm
.
modules
.
diffusionmodules
.
mmdit
.
MMDiT
):
class
ControlNet
(
comfy
.
ldm
.
modules
.
diffusionmodules
.
mmdit
.
MMDiT
):
def
__init__
(
def
__init__
(
...
@@ -30,8 +29,6 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
...
@@ -30,8 +29,6 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
operations
=
operations
operations
=
operations
)
)
self
.
latent_format
=
comfy
.
latent_formats
.
SD3
()
def
forward
(
def
forward
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -42,10 +39,8 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
...
@@ -42,10 +39,8 @@ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
#weird sd3 controlnet specific stuff
#weird sd3 controlnet specific stuff
hint
=
hint
*
self
.
latent_format
.
scale_factor
# self.latent_format.process_in(hint)
y
=
torch
.
zeros_like
(
y
)
y
=
torch
.
zeros_like
(
y
)
if
self
.
context_processor
is
not
None
:
if
self
.
context_processor
is
not
None
:
context
=
self
.
context_processor
(
context
)
context
=
self
.
context_processor
(
context
)
...
...
comfy/controlnet.py
View file @
264caca2
...
@@ -7,6 +7,7 @@ import comfy.model_management
...
@@ -7,6 +7,7 @@ import comfy.model_management
import
comfy.model_detection
import
comfy.model_detection
import
comfy.model_patcher
import
comfy.model_patcher
import
comfy.ops
import
comfy.ops
import
comfy.latent_formats
import
comfy.cldm.cldm
import
comfy.cldm.cldm
import
comfy.t2i_adapter.adapter
import
comfy.t2i_adapter.adapter
...
@@ -38,6 +39,8 @@ class ControlBase:
...
@@ -38,6 +39,8 @@ class ControlBase:
self
.
cond_hint
=
None
self
.
cond_hint
=
None
self
.
strength
=
1.0
self
.
strength
=
1.0
self
.
timestep_percent_range
=
(
0.0
,
1.0
)
self
.
timestep_percent_range
=
(
0.0
,
1.0
)
self
.
latent_format
=
None
self
.
vae
=
None
self
.
global_average_pooling
=
False
self
.
global_average_pooling
=
False
self
.
timestep_range
=
None
self
.
timestep_range
=
None
self
.
compression_ratio
=
8
self
.
compression_ratio
=
8
...
@@ -48,10 +51,12 @@ class ControlBase:
...
@@ -48,10 +51,12 @@ class ControlBase:
self
.
device
=
device
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
previous_controlnet
=
None
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
0.0
,
1.0
)):
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
0.0
,
1.0
)
,
vae
=
None
):
self
.
cond_hint_original
=
cond_hint
self
.
cond_hint_original
=
cond_hint
self
.
strength
=
strength
self
.
strength
=
strength
self
.
timestep_percent_range
=
timestep_percent_range
self
.
timestep_percent_range
=
timestep_percent_range
if
self
.
latent_format
is
not
None
:
self
.
vae
=
vae
return
self
return
self
def
pre_run
(
self
,
model
,
percent_to_timestep_function
):
def
pre_run
(
self
,
model
,
percent_to_timestep_function
):
...
@@ -84,6 +89,8 @@ class ControlBase:
...
@@ -84,6 +89,8 @@ class ControlBase:
c
.
global_average_pooling
=
self
.
global_average_pooling
c
.
global_average_pooling
=
self
.
global_average_pooling
c
.
compression_ratio
=
self
.
compression_ratio
c
.
compression_ratio
=
self
.
compression_ratio
c
.
upscale_algorithm
=
self
.
upscale_algorithm
c
.
upscale_algorithm
=
self
.
upscale_algorithm
c
.
latent_format
=
self
.
latent_format
c
.
vae
=
self
.
vae
def
inference_memory_requirements
(
self
,
dtype
):
def
inference_memory_requirements
(
self
,
dtype
):
if
self
.
previous_controlnet
is
not
None
:
if
self
.
previous_controlnet
is
not
None
:
...
@@ -129,7 +136,7 @@ class ControlBase:
...
@@ -129,7 +136,7 @@ class ControlBase:
return
out
return
out
class
ControlNet
(
ControlBase
):
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
=
None
,
global_average_pooling
=
False
,
compression_ratio
=
8
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
def
__init__
(
self
,
control_model
=
None
,
global_average_pooling
=
False
,
compression_ratio
=
8
,
latent_format
=
None
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
super
().
__init__
(
device
)
super
().
__init__
(
device
)
self
.
control_model
=
control_model
self
.
control_model
=
control_model
self
.
load_device
=
load_device
self
.
load_device
=
load_device
...
@@ -140,6 +147,7 @@ class ControlNet(ControlBase):
...
@@ -140,6 +147,7 @@ class ControlNet(ControlBase):
self
.
global_average_pooling
=
global_average_pooling
self
.
global_average_pooling
=
global_average_pooling
self
.
model_sampling_current
=
None
self
.
model_sampling_current
=
None
self
.
manual_cast_dtype
=
manual_cast_dtype
self
.
manual_cast_dtype
=
manual_cast_dtype
self
.
latent_format
=
latent_format
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
control_prev
=
None
control_prev
=
None
...
@@ -162,7 +170,17 @@ class ControlNet(ControlBase):
...
@@ -162,7 +170,17 @@ class ControlNet(ControlBase):
if
self
.
cond_hint
is
not
None
:
if
self
.
cond_hint
is
not
None
:
del
self
.
cond_hint
del
self
.
cond_hint
self
.
cond_hint
=
None
self
.
cond_hint
=
None
self
.
cond_hint
=
comfy
.
utils
.
common_upscale
(
self
.
cond_hint_original
,
x_noisy
.
shape
[
3
]
*
self
.
compression_ratio
,
x_noisy
.
shape
[
2
]
*
self
.
compression_ratio
,
self
.
upscale_algorithm
,
"center"
).
to
(
dtype
).
to
(
self
.
device
)
compression_ratio
=
self
.
compression_ratio
if
self
.
vae
is
not
None
:
compression_ratio
*=
self
.
vae
.
downscale_ratio
self
.
cond_hint
=
comfy
.
utils
.
common_upscale
(
self
.
cond_hint_original
,
x_noisy
.
shape
[
3
]
*
compression_ratio
,
x_noisy
.
shape
[
2
]
*
compression_ratio
,
self
.
upscale_algorithm
,
"center"
)
if
self
.
vae
is
not
None
:
loaded_models
=
comfy
.
model_management
.
loaded_models
(
only_currently_used
=
True
)
self
.
cond_hint
=
self
.
vae
.
encode
(
self
.
cond_hint
.
movedim
(
1
,
-
1
))
comfy
.
model_management
.
load_models_gpu
(
loaded_models
)
if
self
.
latent_format
is
not
None
:
self
.
cond_hint
=
self
.
latent_format
.
process_in
(
self
.
cond_hint
)
self
.
cond_hint
=
self
.
cond_hint
.
to
(
device
=
self
.
device
,
dtype
=
dtype
)
if
x_noisy
.
shape
[
0
]
!=
self
.
cond_hint
.
shape
[
0
]:
if
x_noisy
.
shape
[
0
]
!=
self
.
cond_hint
.
shape
[
0
]:
self
.
cond_hint
=
broadcast_image_to
(
self
.
cond_hint
,
x_noisy
.
shape
[
0
],
batched_number
)
self
.
cond_hint
=
broadcast_image_to
(
self
.
cond_hint
,
x_noisy
.
shape
[
0
],
batched_number
)
...
@@ -341,7 +359,9 @@ def load_controlnet_mmdit(sd):
...
@@ -341,7 +359,9 @@ def load_controlnet_mmdit(sd):
if
len
(
unexpected
)
>
0
:
if
len
(
unexpected
)
>
0
:
logging
.
debug
(
"unexpected controlnet keys: {}"
.
format
(
unexpected
))
logging
.
debug
(
"unexpected controlnet keys: {}"
.
format
(
unexpected
))
control
=
ControlNet
(
control_model
,
compression_ratio
=
1
,
load_device
=
load_device
,
manual_cast_dtype
=
manual_cast_dtype
)
latent_format
=
comfy
.
latent_formats
.
SD3
()
latent_format
.
shift_factor
=
0
#SD3 controlnet weirdness
control
=
ControlNet
(
control_model
,
compression_ratio
=
1
,
latent_format
=
latent_format
,
load_device
=
load_device
,
manual_cast_dtype
=
manual_cast_dtype
)
return
control
return
control
...
...
comfy_extras/nodes_sd3.py
View file @
264caca2
...
@@ -80,8 +80,23 @@ class CLIPTextEncodeSD3:
...
@@ -80,8 +80,23 @@ class CLIPTextEncodeSD3:
return
([[
cond
,
{
"pooled_output"
:
pooled
}]],
)
return
([[
cond
,
{
"pooled_output"
:
pooled
}]],
)
class
ControlNetApplySD3
(
nodes
.
ControlNetApplyAdvanced
):
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"positive"
:
(
"CONDITIONING"
,
),
"negative"
:
(
"CONDITIONING"
,
),
"control_net"
:
(
"CONTROL_NET"
,
),
"vae"
:
(
"VAE"
,
),
"image"
:
(
"IMAGE"
,
),
"strength"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"start_percent"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"end_percent"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
})
}}
CATEGORY
=
"_for_testing/sd3"
NODE_CLASS_MAPPINGS
=
{
NODE_CLASS_MAPPINGS
=
{
"TripleCLIPLoader"
:
TripleCLIPLoader
,
"TripleCLIPLoader"
:
TripleCLIPLoader
,
"EmptySD3LatentImage"
:
EmptySD3LatentImage
,
"EmptySD3LatentImage"
:
EmptySD3LatentImage
,
"CLIPTextEncodeSD3"
:
CLIPTextEncodeSD3
,
"CLIPTextEncodeSD3"
:
CLIPTextEncodeSD3
,
"ControlNetApplySD3"
:
ControlNetApplySD3
,
}
}
nodes.py
View file @
264caca2
...
@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
...
@@ -783,7 +783,7 @@ class ControlNetApplyAdvanced:
CATEGORY
=
"conditioning"
CATEGORY
=
"conditioning"
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start_percent
,
end_percent
):
def
apply_controlnet
(
self
,
positive
,
negative
,
control_net
,
image
,
strength
,
start_percent
,
end_percent
,
vae
=
None
):
if
strength
==
0
:
if
strength
==
0
:
return
(
positive
,
negative
)
return
(
positive
,
negative
)
...
@@ -800,7 +800,7 @@ class ControlNetApplyAdvanced:
...
@@ -800,7 +800,7 @@ class ControlNetApplyAdvanced:
if
prev_cnet
in
cnets
:
if
prev_cnet
in
cnets
:
c_net
=
cnets
[
prev_cnet
]
c_net
=
cnets
[
prev_cnet
]
else
:
else
:
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start_percent
,
end_percent
))
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start_percent
,
end_percent
)
,
vae
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
cnets
[
prev_cnet
]
=
c_net
cnets
[
prev_cnet
]
=
c_net
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment