Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
3fce8881
Commit
3fce8881
authored
Oct 24, 2023
by
comfyanonymous
Browse files
Sampling code refactor to make it easier to add more conds.
parent
5c65da31
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
42 deletions
+65
-42
comfy/samplers.py
comfy/samplers.py
+65
-42
No files found.
comfy/samplers.py
View file @
3fce8881
...
...
@@ -9,9 +9,58 @@ import math
from
comfy
import
model_base
import
comfy.utils
def
lcm
(
a
,
b
):
#TODO: eventually replace by math.lcm (added in python3.9)
return
abs
(
a
*
b
)
//
math
.
gcd
(
a
,
b
)
class
CONDRegular
:
def
__init__
(
self
,
cond
):
self
.
cond
=
cond
def
can_concat
(
self
,
other
):
if
self
.
cond
.
shape
!=
other
.
cond
.
shape
:
return
False
return
True
def
concat
(
self
,
others
):
conds
=
[
self
.
cond
]
for
x
in
others
:
conds
.
append
(
x
.
cond
)
return
torch
.
cat
(
conds
)
class
CONDCrossAttn
:
def
__init__
(
self
,
cond
):
self
.
cond
=
cond
def
can_concat
(
self
,
other
):
s1
=
self
.
cond
.
shape
s2
=
other
.
cond
.
shape
if
s1
!=
s2
:
if
s1
[
0
]
!=
s2
[
0
]
or
s1
[
2
]
!=
s2
[
2
]:
#these 2 cases should not happen
return
False
mult_min
=
lcm
(
s1
[
1
],
s2
[
1
])
diff
=
mult_min
//
min
(
s1
[
1
],
s2
[
1
])
if
diff
>
4
:
#arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return
False
return
True
def
concat
(
self
,
others
):
conds
=
[
self
.
cond
]
crossattn_max_len
=
self
.
cond
.
shape
[
1
]
for
x
in
others
:
c
=
x
.
cond
crossattn_max_len
=
lcm
(
crossattn_max_len
,
c
.
shape
[
1
])
conds
.
append
(
c
)
out
=
[]
for
c
in
conds
:
if
c
.
shape
[
1
]
<
crossattn_max_len
:
c
=
c
.
repeat
(
1
,
crossattn_max_len
//
c
.
shape
[
1
],
1
)
#padding with repeat doesn't change result
out
.
append
(
c
)
return
torch
.
cat
(
out
)
#The main sampling function shared by all the samplers
#Returns predicted noise
def
sampling_function
(
model_function
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
...
...
@@ -67,7 +116,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
mult
[:,:,:,
area
[
1
]
-
1
-
t
:
area
[
1
]
-
t
]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
conditionning
=
{}
conditionning
[
'c_crossattn'
]
=
cond
[
0
]
conditionning
[
'c_crossattn'
]
=
CONDCrossAttn
(
cond
[
0
]
)
if
'concat'
in
cond
[
1
]:
cond_concat_in
=
cond
[
1
][
'concat'
]
...
...
@@ -76,10 +125,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
for
x
in
cond_concat_in
:
cr
=
x
[:,:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
cropped
.
append
(
cr
)
conditionning
[
'c_concat'
]
=
torch
.
cat
(
cropped
,
dim
=
1
)
conditionning
[
'c_concat'
]
=
CONDRegular
(
torch
.
cat
(
cropped
,
dim
=
1
)
)
if
adm_cond
is
not
None
:
conditionning
[
'c_adm'
]
=
adm_cond
conditionning
[
'c_adm'
]
=
CONDRegular
(
adm_cond
)
control
=
None
if
'control'
in
cond
[
1
]:
...
...
@@ -105,22 +154,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return
True
if
c1
.
keys
()
!=
c2
.
keys
():
return
False
if
'c_crossattn'
in
c1
:
s1
=
c1
[
'c_crossattn'
].
shape
s2
=
c2
[
'c_crossattn'
].
shape
if
s1
!=
s2
:
if
s1
[
0
]
!=
s2
[
0
]
or
s1
[
2
]
!=
s2
[
2
]:
#these 2 cases should not happen
return
False
mult_min
=
lcm
(
s1
[
1
],
s2
[
1
])
diff
=
mult_min
//
min
(
s1
[
1
],
s2
[
1
])
if
diff
>
4
:
#arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return
False
if
'c_concat'
in
c1
:
if
c1
[
'c_concat'
].
shape
!=
c2
[
'c_concat'
].
shape
:
return
False
if
'c_adm'
in
c1
:
if
c1
[
'c_adm'
].
shape
!=
c2
[
'c_adm'
].
shape
:
for
k
in
c1
:
if
not
c1
[
k
].
can_concat
(
c2
[
k
]):
return
False
return
True
...
...
@@ -149,31 +184,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
c_concat
=
[]
c_adm
=
[]
crossattn_max_len
=
0
temp
=
{}
for
x
in
c_list
:
if
'c_crossattn'
in
x
:
c
=
x
[
'c_crossattn'
]
if
crossattn_max_len
==
0
:
crossattn_max_len
=
c
.
shape
[
1
]
else
:
crossattn_max_len
=
lcm
(
crossattn_max_len
,
c
.
shape
[
1
])
c_crossattn
.
append
(
c
)
if
'c_concat'
in
x
:
c_concat
.
append
(
x
[
'c_concat'
])
if
'c_adm'
in
x
:
c_adm
.
append
(
x
[
'c_adm'
])
for
k
in
x
:
cur
=
temp
.
get
(
k
,
[])
cur
.
append
(
x
[
k
])
temp
[
k
]
=
cur
out
=
{}
c_crossattn_out
=
[]
for
c
in
c_crossattn
:
if
c
.
shape
[
1
]
<
crossattn_max_len
:
c
=
c
.
repeat
(
1
,
crossattn_max_len
//
c
.
shape
[
1
],
1
)
#padding with repeat doesn't change result
c_crossattn_out
.
append
(
c
)
if
len
(
c_crossattn_out
)
>
0
:
out
[
'c_crossattn'
]
=
torch
.
cat
(
c_crossattn_out
)
if
len
(
c_concat
)
>
0
:
out
[
'c_concat'
]
=
torch
.
cat
(
c_concat
)
if
len
(
c_adm
)
>
0
:
out
[
'c_adm'
]
=
torch
.
cat
(
c_adm
)
for
k
in
temp
:
conds
=
temp
[
k
]
out
[
k
]
=
conds
[
0
].
concat
(
conds
[
1
:])
return
out
def
calc_cond_uncond_batch
(
model_function
,
cond
,
uncond
,
x_in
,
timestep
,
max_total_area
,
model_options
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment