Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
111f1b52
Commit
111f1b52
authored
Oct 31, 2023
by
comfyanonymous
Browse files
Fix some issues with sampling precision.
parent
7c0f255d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
4 deletions
+6
-4
comfy/model_base.py
comfy/model_base.py
+2
-2
comfy/samplers.py
comfy/samplers.py
+4
-2
No files found.
comfy/model_base.py
View file @
111f1b52
...
...
@@ -44,7 +44,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
else
:
betas
=
make_beta_schedule
(
beta_schedule
,
timesteps
,
linear_start
=
linear_start
,
linear_end
=
linear_end
,
cosine_s
=
cosine_s
)
alphas
=
1.
-
betas
alphas_cumprod
=
np
.
cumprod
(
alphas
,
axis
=
0
)
alphas_cumprod
=
torch
.
tensor
(
np
.
cumprod
(
alphas
,
axis
=
0
)
,
dtype
=
torch
.
float32
)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps
,
=
betas
.
shape
...
...
@@ -56,7 +56,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas
=
torch
.
tensor
(
((
1
-
alphas_cumprod
)
/
alphas_cumprod
)
**
0.5
,
dtype
=
torch
.
float32
)
sigmas
=
((
1
-
alphas_cumprod
)
/
alphas_cumprod
)
**
0.5
self
.
register_buffer
(
'sigmas'
,
sigmas
)
self
.
register_buffer
(
'log_sigmas'
,
sigmas
.
log
())
...
...
comfy/samplers.py
View file @
111f1b52
...
...
@@ -137,10 +137,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
def
calc_cond_uncond_batch
(
model_function
,
cond
,
uncond
,
x_in
,
timestep
,
max_total_area
,
model_options
):
out_cond
=
torch
.
zeros_like
(
x_in
)
out_count
=
torch
.
one
s_like
(
x_in
)
/
100000.0
out_count
=
torch
.
zero
s_like
(
x_in
)
out_uncond
=
torch
.
zeros_like
(
x_in
)
out_uncond_count
=
torch
.
one
s_like
(
x_in
)
/
100000.0
out_uncond_count
=
torch
.
zero
s_like
(
x_in
)
COND
=
0
UNCOND
=
1
...
...
@@ -241,6 +241,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
out_uncond
/=
out_uncond_count
del
out_uncond_count
torch
.
nan_to_num
(
out_cond
,
nan
=
0.0
,
posinf
=
0.0
,
neginf
=
0.0
,
out
=
out_cond
)
#in case out_count or out_uncond_count had some zeros
torch
.
nan_to_num
(
out_uncond
,
nan
=
0.0
,
posinf
=
0.0
,
neginf
=
0.0
,
out
=
out_uncond
)
return
out_cond
,
out_uncond
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment