chenpangpang / ComfyUI / Commits / d1d2fea8

Commit d1d2fea8 authored Oct 25, 2023 by comfyanonymous

Pass extra conds directly to unet.

parent 036f88c6
Showing 1 changed file with 6 additions and 5 deletions

comfy/model_base.py  +6 -5
@@ -50,7 +50,7 @@ class BaseModel(torch.nn.Module):
         self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
         self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
 
-    def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs):
+    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         if c_concat is not None:
             xc = torch.cat([x] + [c_concat], dim=1)
         else:
@@ -60,9 +60,10 @@ class BaseModel(torch.nn.Module):
         xc = xc.to(dtype)
         t = t.to(dtype)
         context = context.to(dtype)
-        if c_adm is not None:
-            c_adm = c_adm.to(dtype)
-        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
+        extra_conds = {}
+        for o in kwargs:
+            extra_conds[o] = kwargs[o].to(dtype)
+        return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
 
     def get_dtype(self):
         return self.diffusion_model.dtype
@@ -107,7 +108,7 @@ class BaseModel(torch.nn.Module):
         out['c_concat'] = comfy.conds.CONDNoiseShape(data)
         adm = self.encode_adm(**kwargs)
         if adm is not None:
-            out['c_adm'] = comfy.conds.CONDRegular(adm)
+            out['y'] = comfy.conds.CONDRegular(adm)
         return out
 
     def load_model_weights(self, sd, unet_prefix=""):
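For context, a minimal sketch of the pattern this commit introduces: apply_model no longer takes a dedicated c_adm argument; any extra conditioning tensors arrive through **kwargs, are cast to the model dtype, and are forwarded to the UNet under the keyword names it already expects (for example y for the ADM embedding, matching the out['y'] change above). The DummyUNet class, TinyBaseModel, and tensor shapes below are hypothetical stand-ins for illustration, not ComfyUI code.

    import torch

    class DummyUNet(torch.nn.Module):
        # Hypothetical stand-in for the diffusion UNet; it only reports which
        # extra conditioning keywords (e.g. y) it received.
        def forward(self, x, timesteps, context=None, y=None, **kwargs):
            print("got y:", None if y is None else tuple(y.shape))
            return x

    class TinyBaseModel(torch.nn.Module):
        # Simplified version of the commit's pattern: no explicit c_adm
        # parameter; every extra cond comes in via **kwargs and is forwarded.
        def __init__(self):
            super().__init__()
            self.diffusion_model = DummyUNet()

        def apply_model(self, x, t, c_concat=None, context=None, **kwargs):
            dtype = torch.float32
            xc = torch.cat([x] + [c_concat], dim=1) if c_concat is not None else x
            extra_conds = {}
            for o in kwargs:  # gather every extra cond and match the model dtype
                extra_conds[o] = kwargs[o].to(dtype)
            return self.diffusion_model(xc.to(dtype), t.to(dtype),
                                        context=context,
                                        **extra_conds).float()

    model = TinyBaseModel()
    x = torch.randn(1, 4, 8, 8)
    t = torch.tensor([10.0])
    # The ADM embedding is passed under the UNet's own keyword name 'y',
    # so apply_model can forward it without knowing about it explicitly.
    model.apply_model(x, t, context=torch.randn(1, 77, 768), y=torch.randn(1, 2816))

Under this scheme, extra_conds() only has to register each tensor under the keyword the UNet expects ('y' rather than 'c_adm'), and apply_model forwards it unchanged.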