Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
2c9dba8d
Commit
2c9dba8d
authored
Nov 12, 2023
by
comfyanonymous
Browse files
sampling_function now has the model object as the argument.
parent
8d80584f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
comfy/samplers.py
comfy/samplers.py
+6
-6
No files found.
comfy/samplers.py
View file @
2c9dba8d
...
@@ -11,7 +11,7 @@ import comfy.conds
...
@@ -11,7 +11,7 @@ import comfy.conds
#The main sampling function shared by all the samplers
#The main sampling function shared by all the samplers
#Returns denoised
#Returns denoised
def
sampling_function
(
model
_function
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
def
sampling_function
(
model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
def
get_area_and_mult
(
conds
,
x_in
,
timestep_in
):
def
get_area_and_mult
(
conds
,
x_in
,
timestep_in
):
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
strength
=
1.0
strength
=
1.0
...
@@ -134,7 +134,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
...
@@ -134,7 +134,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
return
out
return
out
def
calc_cond_uncond_batch
(
model
_function
,
cond
,
uncond
,
x_in
,
timestep
,
max_total_area
,
model_options
):
def
calc_cond_uncond_batch
(
model
,
cond
,
uncond
,
x_in
,
timestep
,
max_total_area
,
model_options
):
out_cond
=
torch
.
zeros_like
(
x_in
)
out_cond
=
torch
.
zeros_like
(
x_in
)
out_count
=
torch
.
ones_like
(
x_in
)
*
1e-37
out_count
=
torch
.
ones_like
(
x_in
)
*
1e-37
...
@@ -221,9 +221,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
...
@@ -221,9 +221,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
c
[
'transformer_options'
]
=
transformer_options
c
[
'transformer_options'
]
=
transformer_options
if
'model_function_wrapper'
in
model_options
:
if
'model_function_wrapper'
in
model_options
:
output
=
model_options
[
'model_function_wrapper'
](
model
_function
,
{
"input"
:
input_x
,
"timestep"
:
timestep_
,
"c"
:
c
,
"cond_or_uncond"
:
cond_or_uncond
}).
chunk
(
batch_chunks
)
output
=
model_options
[
'model_function_wrapper'
](
model
.
apply_model
,
{
"input"
:
input_x
,
"timestep"
:
timestep_
,
"c"
:
c
,
"cond_or_uncond"
:
cond_or_uncond
}).
chunk
(
batch_chunks
)
else
:
else
:
output
=
model
_function
(
input_x
,
timestep_
,
**
c
).
chunk
(
batch_chunks
)
output
=
model
.
apply_model
(
input_x
,
timestep_
,
**
c
).
chunk
(
batch_chunks
)
del
input_x
del
input_x
for
o
in
range
(
batch_chunks
):
for
o
in
range
(
batch_chunks
):
...
@@ -246,7 +246,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
...
@@ -246,7 +246,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
if
math
.
isclose
(
cond_scale
,
1.0
):
if
math
.
isclose
(
cond_scale
,
1.0
):
uncond
=
None
uncond
=
None
cond
,
uncond
=
calc_cond_uncond_batch
(
model
_function
,
cond
,
uncond
,
x
,
timestep
,
max_total_area
,
model_options
)
cond
,
uncond
=
calc_cond_uncond_batch
(
model
,
cond
,
uncond
,
x
,
timestep
,
max_total_area
,
model_options
)
if
"sampler_cfg_function"
in
model_options
:
if
"sampler_cfg_function"
in
model_options
:
args
=
{
"cond"
:
x
-
cond
,
"uncond"
:
x
-
uncond
,
"cond_scale"
:
cond_scale
,
"timestep"
:
timestep
,
"input"
:
x
,
"sigma"
:
timestep
}
args
=
{
"cond"
:
x
-
cond
,
"uncond"
:
x
-
uncond
,
"cond_scale"
:
cond_scale
,
"timestep"
:
timestep
,
"input"
:
x
,
"sigma"
:
timestep
}
return
x
-
model_options
[
"sampler_cfg_function"
](
args
)
return
x
-
model_options
[
"sampler_cfg_function"
](
args
)
...
@@ -258,7 +258,7 @@ class CFGNoisePredictor(torch.nn.Module):
...
@@ -258,7 +258,7 @@ class CFGNoisePredictor(torch.nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
inner_model
=
model
self
.
inner_model
=
model
def
apply_model
(
self
,
x
,
timestep
,
cond
,
uncond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
def
apply_model
(
self
,
x
,
timestep
,
cond
,
uncond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
out
=
sampling_function
(
self
.
inner_model
.
apply_model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
model_options
,
seed
=
seed
)
out
=
sampling_function
(
self
.
inner_model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
model_options
,
seed
=
seed
)
return
out
return
out
def
forward
(
self
,
*
args
,
**
kwargs
):
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
apply_model
(
*
args
,
**
kwargs
)
return
self
.
apply_model
(
*
args
,
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment