Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
8c649357
"vscode:/vscode.git/clone" did not exist on "a107993f106cecb1c375f7a6ae41088d04f29e29"
Commit
8c649357
authored
Jan 03, 2024
by
comfyanonymous
Browse files
Implement noise augmentation for SD 4X upscale model.
parent
ef4f6037
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
33 additions
and
14 deletions
+33
-14
comfy/ldm/modules/diffusionmodules/openaimodel.py
comfy/ldm/modules/diffusionmodules/openaimodel.py
+1
-1
comfy/ldm/modules/diffusionmodules/upscaling.py
comfy/ldm/modules/diffusionmodules/upscaling.py
+8
-4
comfy/model_base.py
comfy/model_base.py
+17
-5
comfy/samplers.py
comfy/samplers.py
+2
-2
comfy/supported_models.py
comfy/supported_models.py
+1
-0
comfy_extras/nodes_sdupscale.py
comfy_extras/nodes_sdupscale.py
+4
-2
No files found.
comfy/ldm/modules/diffusionmodules/openaimodel.py
View file @
8c649357
...
@@ -498,7 +498,7 @@ class UNetModel(nn.Module):
...
@@ -498,7 +498,7 @@ class UNetModel(nn.Module):
if
self
.
num_classes
is
not
None
:
if
self
.
num_classes
is
not
None
:
if
isinstance
(
self
.
num_classes
,
int
):
if
isinstance
(
self
.
num_classes
,
int
):
self
.
label_emb
=
nn
.
Embedding
(
num_classes
,
time_embed_dim
)
self
.
label_emb
=
nn
.
Embedding
(
num_classes
,
time_embed_dim
,
dtype
=
self
.
dtype
,
device
=
device
)
elif
self
.
num_classes
==
"continuous"
:
elif
self
.
num_classes
==
"continuous"
:
print
(
"setting up linear c_adm embedding layer"
)
print
(
"setting up linear c_adm embedding layer"
)
self
.
label_emb
=
nn
.
Linear
(
1
,
time_embed_dim
)
self
.
label_emb
=
nn
.
Linear
(
1
,
time_embed_dim
)
...
...
comfy/ldm/modules/diffusionmodules/upscaling.py
View file @
8c649357
...
@@ -41,8 +41,12 @@ class AbstractLowScaleModel(nn.Module):
...
@@ -41,8 +41,12 @@ class AbstractLowScaleModel(nn.Module):
self
.
register_buffer
(
'sqrt_recip_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
/
alphas_cumprod
)))
self
.
register_buffer
(
'sqrt_recip_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
/
alphas_cumprod
)))
self
.
register_buffer
(
'sqrt_recipm1_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
/
alphas_cumprod
-
1
)))
self
.
register_buffer
(
'sqrt_recipm1_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
/
alphas_cumprod
-
1
)))
def
q_sample
(
self
,
x_start
,
t
,
noise
=
None
):
def
q_sample
(
self
,
x_start
,
t
,
noise
=
None
,
seed
=
None
):
noise
=
default
(
noise
,
lambda
:
torch
.
randn_like
(
x_start
))
if
noise
is
None
:
if
seed
is
None
:
noise
=
torch
.
randn_like
(
x_start
)
else
:
noise
=
torch
.
randn
(
x_start
.
size
(),
dtype
=
x_start
.
dtype
,
layout
=
x_start
.
layout
,
generator
=
torch
.
manual_seed
(
seed
)).
to
(
x_start
.
device
)
return
(
extract_into_tensor
(
self
.
sqrt_alphas_cumprod
.
to
(
x_start
.
device
),
t
,
x_start
.
shape
)
*
x_start
+
return
(
extract_into_tensor
(
self
.
sqrt_alphas_cumprod
.
to
(
x_start
.
device
),
t
,
x_start
.
shape
)
*
x_start
+
extract_into_tensor
(
self
.
sqrt_one_minus_alphas_cumprod
.
to
(
x_start
.
device
),
t
,
x_start
.
shape
)
*
noise
)
extract_into_tensor
(
self
.
sqrt_one_minus_alphas_cumprod
.
to
(
x_start
.
device
),
t
,
x_start
.
shape
)
*
noise
)
...
@@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
...
@@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
super
().
__init__
(
noise_schedule_config
=
noise_schedule_config
)
super
().
__init__
(
noise_schedule_config
=
noise_schedule_config
)
self
.
max_noise_level
=
max_noise_level
self
.
max_noise_level
=
max_noise_level
def
forward
(
self
,
x
,
noise_level
=
None
):
def
forward
(
self
,
x
,
noise_level
=
None
,
seed
=
None
):
if
noise_level
is
None
:
if
noise_level
is
None
:
noise_level
=
torch
.
randint
(
0
,
self
.
max_noise_level
,
(
x
.
shape
[
0
],),
device
=
x
.
device
).
long
()
noise_level
=
torch
.
randint
(
0
,
self
.
max_noise_level
,
(
x
.
shape
[
0
],),
device
=
x
.
device
).
long
()
else
:
else
:
assert
isinstance
(
noise_level
,
torch
.
Tensor
)
assert
isinstance
(
noise_level
,
torch
.
Tensor
)
z
=
self
.
q_sample
(
x
,
noise_level
)
z
=
self
.
q_sample
(
x
,
noise_level
,
seed
=
seed
)
return
z
,
noise_level
return
z
,
noise_level
...
...
comfy/model_base.py
View file @
8c649357
import
torch
import
torch
from
comfy.ldm.modules.diffusionmodules.openaimodel
import
UNetModel
from
comfy.ldm.modules.diffusionmodules.openaimodel
import
UNetModel
,
Timestep
from
comfy.ldm.modules.encoders.noise_aug_modules
import
CLIPEmbeddingNoiseAugmentation
from
comfy.ldm.modules.encoders.noise_aug_modules
import
CLIPEmbeddingNoiseAugmentation
from
comfy.ldm.modules.diffusionmodules.
openaimodel
import
Timestep
from
comfy.ldm.modules.diffusionmodules.
upscaling
import
ImageConcatWithNoiseAugmentation
import
comfy.model_management
import
comfy.model_management
import
comfy.conds
import
comfy.conds
import
comfy.ops
import
comfy.ops
...
@@ -78,8 +78,9 @@ class BaseModel(torch.nn.Module):
...
@@ -78,8 +78,9 @@ class BaseModel(torch.nn.Module):
extra_conds
=
{}
extra_conds
=
{}
for
o
in
kwargs
:
for
o
in
kwargs
:
extra
=
kwargs
[
o
]
extra
=
kwargs
[
o
]
if
hasattr
(
extra
,
"to"
):
if
hasattr
(
extra
,
"dtype"
):
extra
=
extra
.
to
(
dtype
)
if
extra
.
dtype
!=
torch
.
int
and
extra
.
dtype
!=
torch
.
long
:
extra
=
extra
.
to
(
dtype
)
extra_conds
[
o
]
=
extra
extra_conds
[
o
]
=
extra
model_output
=
self
.
diffusion_model
(
xc
,
t
,
context
=
context
,
control
=
control
,
transformer_options
=
transformer_options
,
**
extra_conds
).
float
()
model_output
=
self
.
diffusion_model
(
xc
,
t
,
context
=
context
,
control
=
control
,
transformer_options
=
transformer_options
,
**
extra_conds
).
float
()
...
@@ -368,20 +369,31 @@ class Stable_Zero123(BaseModel):
...
@@ -368,20 +369,31 @@ class Stable_Zero123(BaseModel):
class
SD_X4Upscaler
(
BaseModel
):
class
SD_X4Upscaler
(
BaseModel
):
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
V_PREDICTION
,
device
=
None
):
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
V_PREDICTION
,
device
=
None
):
super
().
__init__
(
model_config
,
model_type
,
device
=
device
)
super
().
__init__
(
model_config
,
model_type
,
device
=
device
)
self
.
noise_augmentor
=
ImageConcatWithNoiseAugmentation
(
noise_schedule_config
=
{
"linear_start"
:
0.0001
,
"linear_end"
:
0.02
},
max_noise_level
=
350
)
def
extra_conds
(
self
,
**
kwargs
):
def
extra_conds
(
self
,
**
kwargs
):
out
=
{}
out
=
{}
image
=
kwargs
.
get
(
"concat_image"
,
None
)
image
=
kwargs
.
get
(
"concat_image"
,
None
)
noise
=
kwargs
.
get
(
"noise"
,
None
)
noise
=
kwargs
.
get
(
"noise"
,
None
)
noise_augment
=
kwargs
.
get
(
"noise_augmentation"
,
0.0
)
device
=
kwargs
[
"device"
]
seed
=
kwargs
[
"seed"
]
-
10
noise_level
=
round
((
self
.
noise_augmentor
.
max_noise_level
)
*
noise_augment
)
if
image
is
None
:
if
image
is
None
:
image
=
torch
.
zeros_like
(
noise
)[:,:
3
]
image
=
torch
.
zeros_like
(
noise
)[:,:
3
]
if
image
.
shape
[
1
:]
!=
noise
.
shape
[
1
:]:
if
image
.
shape
[
1
:]
!=
noise
.
shape
[
1
:]:
image
=
utils
.
common_upscale
(
image
,
noise
.
shape
[
-
1
],
noise
.
shape
[
-
2
],
"bilinear"
,
"center"
)
image
=
utils
.
common_upscale
(
image
.
to
(
device
),
noise
.
shape
[
-
1
],
noise
.
shape
[
-
2
],
"bilinear"
,
"center"
)
noise_level
=
torch
.
tensor
([
noise_level
],
device
=
device
)
if
noise_augment
>
0
:
image
,
noise_level
=
self
.
noise_augmentor
(
image
.
to
(
device
),
noise_level
=
noise_level
,
seed
=
seed
)
image
=
utils
.
resize_to_batch_size
(
image
,
noise
.
shape
[
0
])
image
=
utils
.
resize_to_batch_size
(
image
,
noise
.
shape
[
0
])
out
[
'c_concat'
]
=
comfy
.
conds
.
CONDNoiseShape
(
image
)
out
[
'c_concat'
]
=
comfy
.
conds
.
CONDNoiseShape
(
image
)
out
[
'y'
]
=
comfy
.
conds
.
CONDRegular
(
noise_level
)
return
out
return
out
comfy/samplers.py
View file @
8c649357
...
@@ -603,8 +603,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
...
@@ -603,8 +603,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
latent_image
=
model
.
process_latent_in
(
latent_image
)
latent_image
=
model
.
process_latent_in
(
latent_image
)
if
hasattr
(
model
,
'extra_conds'
):
if
hasattr
(
model
,
'extra_conds'
):
positive
=
encode_model_conds
(
model
.
extra_conds
,
positive
,
noise
,
device
,
"positive"
,
latent_image
=
latent_image
,
denoise_mask
=
denoise_mask
)
positive
=
encode_model_conds
(
model
.
extra_conds
,
positive
,
noise
,
device
,
"positive"
,
latent_image
=
latent_image
,
denoise_mask
=
denoise_mask
,
seed
=
seed
)
negative
=
encode_model_conds
(
model
.
extra_conds
,
negative
,
noise
,
device
,
"negative"
,
latent_image
=
latent_image
,
denoise_mask
=
denoise_mask
)
negative
=
encode_model_conds
(
model
.
extra_conds
,
negative
,
noise
,
device
,
"negative"
,
latent_image
=
latent_image
,
denoise_mask
=
denoise_mask
,
seed
=
seed
)
#make sure each cond area has an opposite one with the same area
#make sure each cond area has an opposite one with the same area
for
c
in
positive
:
for
c
in
positive
:
...
...
comfy/supported_models.py
View file @
8c649357
...
@@ -290,6 +290,7 @@ class SD_X4Upscaler(SD20):
...
@@ -290,6 +290,7 @@ class SD_X4Upscaler(SD20):
unet_extra_config
=
{
unet_extra_config
=
{
"disable_self_attentions"
:
[
True
,
True
,
True
,
False
],
"disable_self_attentions"
:
[
True
,
True
,
True
,
False
],
"num_classes"
:
1000
,
"num_heads"
:
8
,
"num_heads"
:
8
,
"num_head_channels"
:
-
1
,
"num_head_channels"
:
-
1
,
}
}
...
...
comfy_extras/nodes_sdupscale.py
View file @
8c649357
...
@@ -9,7 +9,7 @@ class SD_4XUpscale_Conditioning:
...
@@ -9,7 +9,7 @@ class SD_4XUpscale_Conditioning:
"positive"
:
(
"CONDITIONING"
,),
"positive"
:
(
"CONDITIONING"
,),
"negative"
:
(
"CONDITIONING"
,),
"negative"
:
(
"CONDITIONING"
,),
"scale_ratio"
:
(
"FLOAT"
,
{
"default"
:
4.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
"scale_ratio"
:
(
"FLOAT"
,
{
"default"
:
4.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
}),
#
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1
0
.0, "step": 0.01}),
#TODO
"noise_augmentation"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.
0
01
}),
}}
}}
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
,
"LATENT"
)
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
,
"LATENT"
)
RETURN_NAMES
=
(
"positive"
,
"negative"
,
"latent"
)
RETURN_NAMES
=
(
"positive"
,
"negative"
,
"latent"
)
...
@@ -18,7 +18,7 @@ class SD_4XUpscale_Conditioning:
...
@@ -18,7 +18,7 @@ class SD_4XUpscale_Conditioning:
CATEGORY
=
"conditioning/upscale_diffusion"
CATEGORY
=
"conditioning/upscale_diffusion"
def
encode
(
self
,
images
,
positive
,
negative
,
scale_ratio
):
def
encode
(
self
,
images
,
positive
,
negative
,
scale_ratio
,
noise_augmentation
):
width
=
max
(
1
,
round
(
images
.
shape
[
-
2
]
*
scale_ratio
))
width
=
max
(
1
,
round
(
images
.
shape
[
-
2
]
*
scale_ratio
))
height
=
max
(
1
,
round
(
images
.
shape
[
-
3
]
*
scale_ratio
))
height
=
max
(
1
,
round
(
images
.
shape
[
-
3
]
*
scale_ratio
))
...
@@ -30,11 +30,13 @@ class SD_4XUpscale_Conditioning:
...
@@ -30,11 +30,13 @@ class SD_4XUpscale_Conditioning:
for
t
in
positive
:
for
t
in
positive
:
n
=
[
t
[
0
],
t
[
1
].
copy
()]
n
=
[
t
[
0
],
t
[
1
].
copy
()]
n
[
1
][
'concat_image'
]
=
pixels
n
[
1
][
'concat_image'
]
=
pixels
n
[
1
][
'noise_augmentation'
]
=
noise_augmentation
out_cp
.
append
(
n
)
out_cp
.
append
(
n
)
for
t
in
negative
:
for
t
in
negative
:
n
=
[
t
[
0
],
t
[
1
].
copy
()]
n
=
[
t
[
0
],
t
[
1
].
copy
()]
n
[
1
][
'concat_image'
]
=
pixels
n
[
1
][
'concat_image'
]
=
pixels
n
[
1
][
'noise_augmentation'
]
=
noise_augmentation
out_cn
.
append
(
n
)
out_cn
.
append
(
n
)
latent
=
torch
.
zeros
([
images
.
shape
[
0
],
4
,
height
//
4
,
width
//
4
])
latent
=
torch
.
zeros
([
images
.
shape
[
0
],
4
,
height
//
4
,
width
//
4
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment