chenpangpang / ComfyUI · Commits · 3fce8881

Commit 3fce8881 authored Oct 24, 2023 by comfyanonymous

Sampling code refactor to make it easier to add more conds.

parent 5c65da31
Showing 1 changed file with 65 additions and 42 deletions.

comfy/samplers.py  +65 −42  (view file @ 3fce8881)
@@ -9,9 +9,58 @@ import math
 from comfy import model_base
 import comfy.utils
 
 def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
     return abs(a * b) // math.gcd(a, b)
 
+class CONDRegular:
+    def __init__(self, cond):
+        self.cond = cond
+
+    def can_concat(self, other):
+        if self.cond.shape != other.cond.shape:
+            return False
+        return True
+
+    def concat(self, others):
+        conds = [self.cond]
+        for x in others:
+            conds.append(x.cond)
+        return torch.cat(conds)
+
+class CONDCrossAttn:
+    def __init__(self, cond):
+        self.cond = cond
+
+    def can_concat(self, other):
+        s1 = self.cond.shape
+        s2 = other.cond.shape
+        if s1 != s2:
+            if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
+                return False
+
+            mult_min = lcm(s1[1], s2[1])
+            diff = mult_min // min(s1[1], s2[1])
+            if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
+                return False
+        return True
+
+    def concat(self, others):
+        conds = [self.cond]
+        crossattn_max_len = self.cond.shape[1]
+        for x in others:
+            c = x.cond
+            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+            conds.append(c)
+
+        out = []
+        for c in conds:
+            if c.shape[1] < crossattn_max_len:
+                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
+            out.append(c)
+        return torch.cat(out)
+
 #The main sampling function shared by all the samplers
 #Returns predicted noise
 def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
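A note on the padding trick in CONDCrossAttn.concat: cross-attention conds of different token lengths are batched by repeating each tensor along dim 1 up to the least common multiple of the lengths, per the "padding with repeat doesn't change result" comment above. A minimal standalone sketch of just that mechanism, using dummy shapes (77 and 154 tokens) rather than real conds:

    import math
    import torch

    def lcm(a, b): # same helper as in the diff; math.lcm needs python>=3.9
        return abs(a * b) // math.gcd(a, b)

    # two dummy cross-attention conds: batch 1, differing token counts, 768 channels
    a = torch.randn(1, 77, 768)
    b = torch.randn(1, 154, 768)

    max_len = lcm(a.shape[1], b.shape[1]) # 154
    padded = []
    for c in (a, b):
        if c.shape[1] < max_len:
            # repeat along the token dimension until the lengths match
            c = c.repeat(1, max_len // c.shape[1], 1)
        padded.append(c)

    print(torch.cat(padded).shape) # torch.Size([2, 154, 768])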
@@ -67,7 +116,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
                     mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
 
         conditionning = {}
-        conditionning['c_crossattn'] = cond[0]
+        conditionning['c_crossattn'] = CONDCrossAttn(cond[0])
 
         if 'concat' in cond[1]:
             cond_concat_in = cond[1]['concat']
@@ -76,10 +125,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
                 for x in cond_concat_in:
                     cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
                     cropped.append(cr)
-                conditionning['c_concat'] = torch.cat(cropped, dim=1)
+                conditionning['c_concat'] = CONDRegular(torch.cat(cropped, dim=1))
 
         if adm_cond is not None:
-            conditionning['c_adm'] = adm_cond
+            conditionning['c_adm'] = CONDRegular(adm_cond)
 
         control = None
         if 'control' in cond[1]:
@@ -105,22 +154,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
             return True
 
         if c1.keys() != c2.keys():
            return False
-        if 'c_crossattn' in c1:
-            s1 = c1['c_crossattn'].shape
-            s2 = c2['c_crossattn'].shape
-            if s1 != s2:
-                if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
-                    return False
-
-                mult_min = lcm(s1[1], s2[1])
-                diff = mult_min // min(s1[1], s2[1])
-                if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
-                    return False
-        if 'c_concat' in c1:
-            if c1['c_concat'].shape != c2['c_concat'].shape:
-                return False
-        if 'c_adm' in c1:
-            if c1['c_adm'].shape != c2['c_adm'].shape:
-                return False
+        for k in c1:
+            if not c1[k].can_concat(c2[k]):
+                return False
         return True
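This hunk is the heart of the refactor: instead of one hard-coded shape check per cond key, batching compatibility is delegated to the wrapper objects, so a new kind of cond only needs a class exposing can_concat and concat. A hedged sketch of what such an extension could look like; CONDConstant and c_mything are purely illustrative names, not part of this commit:

    # hypothetical cond type: only batchable with an identical constant
    class CONDConstant:
        def __init__(self, cond):
            self.cond = cond

        def can_concat(self, other):
            return self.cond == other.cond

        def concat(self, others):
            return self.cond

    c1 = {'c_mything': CONDConstant(1.0)}
    c2 = {'c_mything': CONDConstant(1.0)}
    # mirrors the new generic loop above
    print(all(c1[k].can_concat(c2[k]) for k in c1)) # True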
@@ -149,31 +184,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
-        c_crossattn = []
-        c_concat = []
-        c_adm = []
-        crossattn_max_len = 0
+        temp = {}
         for x in c_list:
-            if 'c_crossattn' in x:
-                c = x['c_crossattn']
-                if crossattn_max_len == 0:
-                    crossattn_max_len = c.shape[1]
-                else:
-                    crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
-                c_crossattn.append(c)
-            if 'c_concat' in x:
-                c_concat.append(x['c_concat'])
-            if 'c_adm' in x:
-                c_adm.append(x['c_adm'])
+            for k in x:
+                cur = temp.get(k, [])
+                cur.append(x[k])
+                temp[k] = cur
 
         out = {}
-        c_crossattn_out = []
-        for c in c_crossattn:
-            if c.shape[1] < crossattn_max_len:
-                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
-            c_crossattn_out.append(c)
-        if len(c_crossattn_out) > 0:
-            out['c_crossattn'] = torch.cat(c_crossattn_out)
-        if len(c_concat) > 0:
-            out['c_concat'] = torch.cat(c_concat)
-        if len(c_adm) > 0:
-            out['c_adm'] = torch.cat(c_adm)
+        for k in temp:
+            conds = temp[k]
+            out[k] = conds[0].concat(conds[1:])
         return out
 
     def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
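End-to-end, cond_cat now just groups the per-area cond dicts by key and lets each wrapper class build its own batch. A small self-contained walkthrough of that flow, with the classes condensed from the hunks above and dummy tensors standing in for real conds:

    import math
    import torch

    def lcm(a, b):
        return abs(a * b) // math.gcd(a, b)

    class CONDRegular:
        def __init__(self, cond):
            self.cond = cond
        def concat(self, others):
            return torch.cat([self.cond] + [x.cond for x in others])

    class CONDCrossAttn(CONDRegular):
        def concat(self, others): # pads the token dim to the lcm of all lengths
            conds = [self.cond] + [x.cond for x in others]
            max_len = 1
            for c in conds:
                max_len = lcm(max_len, c.shape[1])
            return torch.cat([c.repeat(1, max_len // c.shape[1], 1) for c in conds])

    def cond_cat(c_list): # condensed version of the new cond_cat above
        temp = {}
        for x in c_list:
            for k in x:
                temp.setdefault(k, []).append(x[k])
        return {k: v[0].concat(v[1:]) for k, v in temp.items()}

    c_list = [{'c_crossattn': CONDCrossAttn(torch.randn(1, 77, 768)),
               'c_adm': CONDRegular(torch.randn(1, 1280))},
              {'c_crossattn': CONDCrossAttn(torch.randn(1, 154, 768)),
               'c_adm': CONDRegular(torch.randn(1, 1280))}]
    out = cond_cat(c_list)
    print(out['c_crossattn'].shape) # torch.Size([2, 154, 768])
    print(out['c_adm'].shape)       # torch.Size([2, 1280])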