chenpangpang / ComfyUI / Commits / 2b14041d

Commit 2b14041d, authored Jun 13, 2023 by comfyanonymous

Remove useless code.

Parent: 274dff32
Changes: showing 8 changed files with 0 additions and 870 deletions (+0 -870).
comfy/k_diffusion/augmentation.py      +0 -105
comfy/k_diffusion/config.py            +0 -110
comfy/k_diffusion/evaluation.py        +0 -134
comfy/k_diffusion/gns.py               +0 -99
comfy/k_diffusion/layers.py            +0 -246
comfy/k_diffusion/models/__init__.py   +0 -1
comfy/k_diffusion/models/image_v1.py   +0 -156
comfy/k_diffusion/utils.py             +0 -19
comfy/k_diffusion/augmentation.py (deleted, 100644 → 0)
from functools import reduce
import math
import operator

import numpy as np
from skimage import transform
import torch
from torch import nn


def translate2d(tx, ty):
    mat = [[1, 0, tx],
           [0, 1, ty],
           [0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


def scale2d(sx, sy):
    mat = [[sx, 0, 0],
           [0, sy, 0],
           [0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


def rotate2d(theta):
    mat = [[torch.cos(theta), torch.sin(-theta), 0],
           [torch.sin(theta), torch.cos(theta), 0],
           [0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


class KarrasAugmentationPipeline:
    def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
        self.a_prob = a_prob
        self.a_scale = a_scale
        self.a_aniso = a_aniso
        self.a_trans = a_trans

    def __call__(self, image):
        h, w = image.size
        mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]

        # x-flip
        a0 = torch.randint(2, []).float()
        mats.append(scale2d(1 - 2 * a0, 1))
        # y-flip
        do = (torch.rand([]) < self.a_prob).float()
        a1 = torch.randint(2, []).float() * do
        mats.append(scale2d(1, 1 - 2 * a1))
        # scaling
        do = (torch.rand([]) < self.a_prob).float()
        a2 = torch.randn([]) * do
        mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
        # rotation
        do = (torch.rand([]) < self.a_prob).float()
        a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        mats.append(rotate2d(-a3))
        # anisotropy
        do = (torch.rand([]) < self.a_prob).float()
        a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        a5 = torch.randn([]) * do
        mats.append(rotate2d(a4))
        mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
        mats.append(rotate2d(-a4))
        # translation
        do = (torch.rand([]) < self.a_prob).float()
        a6 = torch.randn([]) * do
        a7 = torch.randn([]) * do
        mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))

        # form the transformation matrix and conditioning vector
        mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
        mat = reduce(operator.matmul, mats)
        cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])

        # apply the transformation
        image_orig = np.array(image, dtype=np.float32) / 255
        if image_orig.ndim == 2:
            image_orig = image_orig[..., None]
        tf = transform.AffineTransform(mat.numpy())
        image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
        image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
        image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
        return image, image_orig, cond


class KarrasAugmentWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
        if aug_cond is None:
            aug_cond = input.new_zeros([input.shape[0], 9])
        if mapping_cond is None:
            mapping_cond = aug_cond
        else:
            mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
        return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)

    def set_skip_stages(self, skip_stages):
        return self.inner_model.set_skip_stages(skip_stages)

    def set_patch_size(self, patch_size):
        return self.inner_model.set_patch_size(patch_size)
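For reference, a minimal usage sketch (hypothetical, assuming the definitions above are in scope; the file name is illustrative): the pipeline maps a PIL image to an augmented tensor, the unaugmented tensor, and a 9-entry conditioning vector, which KarrasAugmentWrapper folds into the model's mapping conditioning.

```python
# Hypothetical usage sketch, assuming the definitions above are in scope.
from PIL import Image

aug = KarrasAugmentationPipeline(a_prob=0.12)
image, image_orig, cond = aug(Image.open('example.png'))  # illustrative path
print(image.shape, image_orig.shape, cond.shape)  # (C, H, W), (C, H, W), (9,)

# During training the conditioning vector would be fed through the wrapper:
# wrapped = KarrasAugmentWrapper(inner_model)
# out = wrapped(noised, sigma, aug_cond=cond[None])
```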
comfy/k_diffusion/config.py (deleted, 100644 → 0)
from functools import partial
import json
import math
import warnings

from jsonmerge import merge

from . import augmentation, layers, models, utils


def load_config(file):
    defaults = {
        'model': {
            'sigma_data': 1.,
            'patch_size': 1,
            'dropout_rate': 0.,
            'augment_wrapper': True,
            'augment_prob': 0.,
            'mapping_cond_dim': 0,
            'unet_cond_dim': 0,
            'cross_cond_dim': 0,
            'cross_attn_depths': None,
            'skip_stages': 0,
            'has_variance': False,
        },
        'dataset': {
            'type': 'imagefolder',
        },
        'optimizer': {
            'type': 'adamw',
            'lr': 1e-4,
            'betas': [0.95, 0.999],
            'eps': 1e-6,
            'weight_decay': 1e-3,
        },
        'lr_sched': {
            'type': 'inverse',
            'inv_gamma': 20000.,
            'power': 1.,
            'warmup': 0.99,
        },
        'ema_sched': {
            'type': 'inverse',
            'power': 0.6667,
            'max_value': 0.9999
        },
    }
    config = json.load(file)
    return merge(defaults, config)


def make_model(config):
    config = config['model']
    assert config['type'] == 'image_v1'
    model = models.ImageDenoiserModelV1(
        config['input_channels'],
        config['mapping_out'],
        config['depths'],
        config['channels'],
        config['self_attn_depths'],
        config['cross_attn_depths'],
        patch_size=config['patch_size'],
        dropout_rate=config['dropout_rate'],
        mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
        unet_cond_dim=config['unet_cond_dim'],
        cross_cond_dim=config['cross_cond_dim'],
        skip_stages=config['skip_stages'],
        has_variance=config['has_variance'],
    )
    if config['augment_wrapper']:
        model = augmentation.KarrasAugmentWrapper(model)
    return model


def make_denoiser_wrapper(config):
    config = config['model']
    sigma_data = config.get('sigma_data', 1.)
    has_variance = config.get('has_variance', False)
    if not has_variance:
        return partial(layers.Denoiser, sigma_data=sigma_data)
    return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)


def make_sample_density(config):
    sd_config = config['sigma_sample_density']
    sigma_data = config['sigma_data']
    if sd_config['type'] == 'lognormal':
        loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
        scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
        return partial(utils.rand_log_normal, loc=loc, scale=scale)
    if sd_config['type'] == 'loglogistic':
        loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
        scale = sd_config['scale'] if 'scale' in sd_config else 0.5
        min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
        max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
        return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
    if sd_config['type'] == 'loguniform':
        min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
        max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
        return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
    if sd_config['type'] == 'v-diffusion':
        min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
        max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
        return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
    if sd_config['type'] == 'split-lognormal':
        loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
        scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
        scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
        return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
    raise ValueError('Unknown sample density type')
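As a sketch of the config flow (hypothetical, assuming the functions above are in scope; every JSON value below is made up, not a shipped default): load_config merges user JSON over the defaults, and make_sample_density turns the sigma_sample_density block into a callable that draws training noise levels.

```python
# Hypothetical sketch; all config values here are illustrative.
import io, json

user_cfg = {'model': {
    'type': 'image_v1', 'input_channels': 3, 'mapping_out': 256,
    'depths': [2, 4, 4], 'channels': [128, 256, 512],
    'self_attn_depths': [False, True, True],
    'sigma_data': 0.5, 'sigma_min': 1e-2, 'sigma_max': 80,
    'sigma_sample_density': {'type': 'lognormal', 'mean': -1.2, 'std': 1.2},
}}
config = load_config(io.StringIO(json.dumps(user_cfg)))  # merged over defaults
sample_density = make_sample_density(config['model'])
# sigmas = sample_density([4096])  # would draw noise levels via the utils module
```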
comfy/k_diffusion/evaluation.py (deleted, 100644 → 0)
import math
import os
from pathlib import Path

from cleanfid.inception_torchscript import InceptionV3W
import clip
from resize_right import resize
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm.auto import trange

from . import utils


class InceptionV3FeatureExtractor(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
        digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
        utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
        self.model = InceptionV3W(str(path), resize_inside=False).to(device)
        self.size = (299, 299)

    def forward(self, x):
        if x.shape[2:4] != self.size:
            x = resize(x, out_shape=self.size, pad_mode='reflect')
        if x.shape[1] == 1:
            x = torch.cat([x] * 3, dim=1)
        x = (x * 127.5 + 127.5).clamp(0, 255)
        return self.model(x)


class CLIPFeatureExtractor(nn.Module):
    def __init__(self, name='ViT-L/14@336px', device='cpu'):
        super().__init__()
        self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
        self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                              std=(0.26862954, 0.26130258, 0.27577711))
        self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)

    def forward(self, x):
        if x.shape[2:4] != self.size:
            x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
        x = self.normalize(x)
        x = self.model.encode_image(x).float()
        x = F.normalize(x) * x.shape[1] ** 0.5
        return x


def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
    n_per_proc = math.ceil(n / accelerator.num_processes)
    feats_all = []
    try:
        for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
            cur_batch_size = min(n - i, batch_size)
            samples = sample_fn(cur_batch_size)[:cur_batch_size]
            feats_all.append(accelerator.gather(extractor_fn(samples)))
    except StopIteration:
        pass
    return torch.cat(feats_all)[:n]


def polynomial_kernel(x, y):
    d = x.shape[-1]
    dot = x @ y.transpose(-2, -1)
    return (dot / d + 1) ** 3


def squared_mmd(x, y, kernel=polynomial_kernel):
    m = x.shape[-2]
    n = y.shape[-2]
    kxx = kernel(x, x)
    kyy = kernel(y, y)
    kxy = kernel(x, y)
    kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
    kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
    kxy_sum = kxy.sum([-1, -2])
    term_1 = kxx_sum / m / (m - 1)
    term_2 = kyy_sum / n / (n - 1)
    term_3 = kxy_sum * 2 / m / n
    return term_1 + term_2 - term_3


@utils.tf32_mode(matmul=False)
def kid(x, y, max_size=5000):
    x_size, y_size = x.shape[0], y.shape[0]
    n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
    total_mmd = x.new_zeros([])
    for i in range(n_partitions):
        cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
        cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
        total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
    return total_mmd / n_partitions


class _MatrixSquareRootEig(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        vals, vecs = torch.linalg.eigh(a)
        ctx.save_for_backward(vals, vecs)
        return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)

    @staticmethod
    def backward(ctx, grad_output):
        vals, vecs = ctx.saved_tensors
        d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
        vecs_t = vecs.transpose(-2, -1)
        return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t


def sqrtm_eig(a):
    if a.ndim < 2:
        raise RuntimeError('tensor of matrices must have at least 2 dimensions')
    if a.shape[-2] != a.shape[-1]:
        raise RuntimeError('tensor must be batches of square matrices')
    return _MatrixSquareRootEig.apply(a)


@utils.tf32_mode(matmul=False)
def fid(x, y, eps=1e-8):
    x_mean = x.mean(dim=0)
    y_mean = y.mean(dim=0)
    mean_term = (x_mean - y_mean).pow(2).sum()
    x_cov = torch.cov(x.T)
    y_cov = torch.cov(y.T)
    eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
    x_cov = x_cov + eps_eye
    y_cov = y_cov + eps_eye
    x_cov_sqrt = sqrtm_eig(x_cov)
    cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
    return mean_term + cov_term
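A quick sanity check one could run against these metrics (hypothetical, assuming squared_mmd and fid above are in scope; their tf32_mode decorator comes from the deleted utils module): features drawn from identical distributions should score near zero, while a mean shift should be clearly visible.

```python
# Hypothetical sanity check on random features; values are illustrative.
import torch

torch.manual_seed(0)
x = torch.randn(2000, 64)
y = torch.randn(2000, 64)        # same distribution as x
z = torch.randn(2000, 64) + 1.0  # mean-shifted distribution

print(squared_mmd(x, y).item())  # close to 0 for matching distributions
print(squared_mmd(x, z).item())  # clearly positive
print(fid(x, y).item())          # small; the shift in z would inflate it
```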
comfy/k_diffusion/gns.py (deleted, 100644 → 0)
import torch
from torch import nn


class DDPGradientStatsHook:
    def __init__(self, ddp_module):
        try:
            ddp_module.register_comm_hook(self, self._hook_fn)
        except AttributeError:
            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
        self._clear_state()

    def _clear_state(self):
        self.bucket_sq_norms_small_batch = []
        self.bucket_sq_norms_large_batch = []

    @staticmethod
    def _hook_fn(self, bucket):
        buf = bucket.buffer()
        self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
        fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
        def callback(fut):
            buf = fut.value()[0]
            self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
            return buf
        return fut.then(callback)

    def get_stats(self):
        sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
        sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
        self._clear_state()
        stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
        torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
        return stats[0].item(), stats[1].item()


class GradientNoiseScale:
    """Calculates the gradient noise scale (1 / SNR), or critical batch size,
    from _An Empirical Model of Large-Batch Training_,
    https://arxiv.org/abs/1812.06162.

    Args:
        beta (float): The decay factor for the exponential moving averages used to
            calculate the gradient noise scale.
            Default: 0.9998
        eps (float): Added for numerical stability.
            Default: 1e-8
    """

    def __init__(self, beta=0.9998, eps=1e-8):
        self.beta = beta
        self.eps = eps
        self.ema_sq_norm = 0.
        self.ema_var = 0.
        self.beta_cumprod = 1.
        self.gradient_noise_scale = float('nan')

    def state_dict(self):
        """Returns the state of the object as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the object's state.

        Args:
            state_dict (dict): object state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
        """Updates the state with a new batch's gradient statistics, and returns the
        current gradient noise scale.

        Args:
            sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
                per sample gradients.
            sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
                per sample gradients.
            n_small_batch (int): The batch size of the individual microbatch or per sample
                gradients (1 if per sample).
            n_large_batch (int): The total batch size of the mean of the microbatch or
                per sample gradients.
        """
        est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
        est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
        self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
        self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
        self.beta_cumprod *= self.beta
        self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
        return self.gradient_noise_scale

    def get_gns(self):
        """Returns the current gradient noise scale."""
        return self.gradient_noise_scale

    def get_stats(self):
        """Returns the current (debiased) estimates of the squared mean gradient
        and gradient variance."""
        return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
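Following McCandlish et al., update() combines a small-batch and a large-batch squared gradient norm into unbiased estimates of the squared mean gradient, (B_large·|G_large|² − B_small·|G_small|²) / (B_large − B_small), and the gradient variance, (|G_small|² − |G_large|²) / (1/B_small − 1/B_large), then reports their EMA-smoothed ratio. A hypothetical usage sketch (assuming the class above is in scope; the statistics are made-up numbers):

```python
# Hypothetical usage sketch with invented gradient statistics.
gns = GradientNoiseScale(beta=0.9998)
scale = gns.update(sq_norm_small_batch=1.7,   # mean ||g_microbatch||^2
                   sq_norm_large_batch=0.9,   # ||mean gradient||^2
                   n_small_batch=32, n_large_batch=512)
print(scale)            # noise scale, read as a critical batch size estimate
print(gns.get_stats())  # debiased (squared mean gradient, gradient variance)
```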
comfy/k_diffusion/layers.py (deleted, 100644 → 0)
import math

from einops import rearrange, repeat
import torch
from torch import nn
from torch.nn import functional as F

from . import utils


# Karras et al. preconditioned denoiser

class Denoiser(nn.Module):
    """A Karras et al. preconditioner for denoising diffusion models."""

    def __init__(self, inner_model, sigma_data=1.):
        super().__init__()
        self.inner_model = inner_model
        self.sigma_data = sigma_data

    def get_scalings(self, sigma):
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return c_skip, c_out, c_in

    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
        target = (input - c_skip * noised_input) / c_out
        return (model_output - target).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip


class DenoiserWithVariance(Denoiser):
    def loss(self, input, noise, sigma, **kwargs):
        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
        model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
        logvar = utils.append_dims(logvar, model_output.ndim)
        target = (input - c_skip * noised_input) / c_out
        losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
        return losses.flatten(1).mean(1)


# Residual blocks

class ResidualBlock(nn.Module):
    def __init__(self, *main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)


# Noise level (and other) conditioning

class ConditionedModule(nn.Module):
    pass


class UnconditionedModule(ConditionedModule):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, input, cond=None):
        return self.module(input)


class ConditionedSequential(nn.Sequential, ConditionedModule):
    def forward(self, input, cond):
        for module in self:
            if isinstance(module, ConditionedModule):
                input = module(input, cond)
            else:
                input = module(input)
        return input


class ConditionedResidualBlock(ConditionedModule):
    def __init__(self, *main, skip=None):
        super().__init__()
        self.main = ConditionedSequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input, cond):
        skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
        return self.main(input, cond) + skip


class AdaGN(ConditionedModule):
    def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.cond_key = cond_key
        self.mapper = nn.Linear(feats_in, c_out * 2)

    def forward(self, input, cond):
        weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
        input = F.group_norm(input, self.num_groups, eps=self.eps)
        return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)


# Attention

class SelfAttention2d(ConditionedModule):
    def __init__(self, c_in, n_head, norm, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm_in = norm(c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv2d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input, cond):
        n, c, h, w = input.shape
        qkv = self.qkv_proj(self.norm_in(input, cond))
        qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        scale = k.shape[3] ** -0.25
        att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
        att = self.dropout(att)
        y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
        return input + self.out_proj(y)


class CrossAttention2d(ConditionedModule):
    def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
                 cond_key='cross', cond_key_padding='cross_padding'):
        super().__init__()
        assert c_dec % n_head == 0
        self.cond_key = cond_key
        self.cond_key_padding = cond_key_padding
        self.norm_enc = nn.LayerNorm(c_enc)
        self.norm_dec = norm_dec(c_dec)
        self.n_head = n_head
        self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
        self.kv_proj = nn.Linear(c_enc, c_dec * 2)
        self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input, cond):
        n, c, h, w = input.shape
        q = self.q_proj(self.norm_dec(input, cond))
        q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
        kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
        kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
        k, v = kv.chunk(2, dim=1)
        scale = k.shape[3] ** -0.25
        att = ((q * scale) @ (k.transpose(2, 3) * scale))
        att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
        att = att.softmax(3)
        att = self.dropout(att)
        y = (att @ v).transpose(2, 3)
        y = y.contiguous().view([n, c, h, w])
        return input + self.out_proj(y)


# Downsampling/upsampling

_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
         0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
         -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
         0.44638532400131226, 0.13550527393817902, -0.066637322306633,
         -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']


class Downsample2d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([_kernels[kernel]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer('kernel', kernel_1d.T @ kernel_1d)

    def forward(self, x):
        x = F.pad(x, (self.pad,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv2d(x, weight, stride=2)


class Upsample2d(nn.Module):
    def __init__(self, kernel='linear', pad_mode='reflect'):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([_kernels[kernel]]) * 2
        self.pad = kernel_1d.shape[1] // 2 - 1
        self.register_buffer('kernel', kernel_1d.T @ kernel_1d)

    def forward(self, x):
        x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)


# Embeddings

class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)


# U-Nets

class UNet(ConditionedModule):
    def __init__(self, d_blocks, u_blocks, skip_stages=0):
        super().__init__()
        self.d_blocks = nn.ModuleList(d_blocks)
        self.u_blocks = nn.ModuleList(u_blocks)
        self.skip_stages = skip_stages

    def forward(self, input, cond):
        skips = []
        for block in self.d_blocks[self.skip_stages:]:
            input = block(input, cond)
            skips.append(input)
        for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
            input = block(input, cond, skip if i > 0 else None)
        return input
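The Denoiser class implements the Karras et al. (EDM) preconditioning: the network sees c_in * x at noise level sigma, and the denoised estimate is c_out * F(c_in * x, sigma) + c_skip * x, with c_skip = sigma_data^2 / (sigma^2 + sigma_data^2), c_out = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2), and c_in = 1 / sqrt(sigma^2 + sigma_data^2). A toy sketch (hypothetical, assuming Denoiser above is in scope and that utils.append_dims broadcasts per-sample sigmas to the input's rank):

```python
# Hypothetical sketch with a stand-in inner model.
import torch
from torch import nn

class DummyInner(nn.Module):
    """Stands in for a real U-Net; ignores sigma."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, input, sigma):
        return self.conv(input)

denoiser = Denoiser(DummyInner(), sigma_data=0.5)
x = torch.randn(2, 3, 16, 16)
sigma = torch.tensor([1.0, 10.0])
print(denoiser(x, sigma).shape)                            # (2, 3, 16, 16)
print(denoiser.loss(x, torch.randn_like(x), sigma).shape)  # (2,), one loss per sample
```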
comfy/k_diffusion/models/__init__.py (deleted, 100644 → 0)
from .image_v1 import ImageDenoiserModelV1
comfy/k_diffusion/models/image_v1.py (deleted, 100644 → 0)
import math

import torch
from torch import nn
from torch.nn import functional as F

from .. import layers, utils


def orthogonal_(module):
    nn.init.orthogonal_(module.weight)
    return module


class ResConvBlock(layers.ConditionedResidualBlock):
    def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
        skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
        super().__init__(
            layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
            nn.GELU(),
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.Dropout2d(dropout_rate, inplace=True),
            layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
            nn.GELU(),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
            nn.Dropout2d(dropout_rate, inplace=True),
            skip=skip)


class DBlock(layers.ConditionedSequential):
    def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64,
                 dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
        modules = [nn.Identity()]
        for i in range(n_layers):
            my_c_in = c_in if i == 0 else c_mid
            my_c_out = c_mid if i < n_layers - 1 else c_out
            modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
            if self_attn:
                norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
                modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
            if cross_attn:
                norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
                modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
        super().__init__(*modules)
        self.set_downsample(downsample)

    def set_downsample(self, downsample):
        self[0] = layers.Downsample2d() if downsample else nn.Identity()
        return self


class UBlock(layers.ConditionedSequential):
    def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64,
                 dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
        modules = []
        for i in range(n_layers):
            my_c_in = c_in if i == 0 else c_mid
            my_c_out = c_mid if i < n_layers - 1 else c_out
            modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
            if self_attn:
                norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
                modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
            if cross_attn:
                norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
                modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
        modules.append(nn.Identity())
        super().__init__(*modules)
        self.set_upsample(upsample)

    def forward(self, input, cond, skip=None):
        if skip is not None:
            input = torch.cat([input, skip], dim=1)
        return super().forward(input, cond)

    def set_upsample(self, upsample):
        self[-1] = layers.Upsample2d() if upsample else nn.Identity()
        return self


class MappingNet(nn.Sequential):
    def __init__(self, feats_in, feats_out, n_layers=2):
        layers = []
        for i in range(n_layers):
            layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
            layers.append(nn.GELU())
        super().__init__(*layers)


class ImageDenoiserModelV1(nn.Module):
    def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None,
                 mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0.,
                 patch_size=1, skip_stages=0, has_variance=False):
        super().__init__()
        self.c_in = c_in
        self.channels = channels
        self.unet_cond_dim = unet_cond_dim
        self.patch_size = patch_size
        self.has_variance = has_variance
        self.timestep_embed = layers.FourierFeatures(1, feats_in)
        if mapping_cond_dim > 0:
            self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
        self.mapping = MappingNet(feats_in, feats_in)
        self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
        self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
        if cross_cond_dim == 0:
            cross_attn_depths = [False] * len(self_attn_depths)
        d_blocks, u_blocks = [], []
        for i in range(len(depths)):
            my_c_in = channels[max(0, i - 1)]
            d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages,
                                   self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i],
                                   c_enc=cross_cond_dim, dropout_rate=dropout_rate))
        for i in range(len(depths)):
            my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
            my_c_out = channels[max(0, i - 1)]
            u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages,
                                   self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i],
                                   c_enc=cross_cond_dim, dropout_rate=dropout_rate))
        self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)

    def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None,
                cross_cond_padding=None, return_variance=False):
        c_noise = sigma.log() / 4
        timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
        mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
        mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
        cond = {'cond': mapping_out}
        if unet_cond is not None:
            input = torch.cat([input, unet_cond], dim=1)
        if cross_cond is not None:
            cond['cross'] = cross_cond
            cond['cross_padding'] = cross_cond_padding
        if self.patch_size > 1:
            input = F.pixel_unshuffle(input, self.patch_size)
        input = self.proj_in(input)
        input = self.u_net(input, cond)
        input = self.proj_out(input)
        if self.has_variance:
            input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
        if self.patch_size > 1:
            input = F.pixel_shuffle(input, self.patch_size)
        if self.has_variance and return_variance:
            return input, logvar
        return input

    def set_skip_stages(self, skip_stages):
        self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
        self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
        self.u_net.skip_stages = skip_stages
        for i, block in enumerate(self.u_net.d_blocks):
            block.set_downsample(i > skip_stages)
        for i, block in enumerate(reversed(self.u_net.u_blocks)):
            block.set_upsample(i > skip_stages)
        return self

    def set_patch_size(self, patch_size):
        self.patch_size = patch_size
        self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2,
                                 self.channels[max(0, self.u_net.skip_stages - 1)], 1)
        self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)],
                                  self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
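A minimal instantiation sketch (hypothetical, assuming the classes above and the deleted layers/utils modules are in scope; the hyperparameters are made up, not from any shipped config):

```python
# Hypothetical instantiation with invented hyperparameters.
import torch

model = ImageDenoiserModelV1(
    c_in=3, feats_in=128,
    depths=[2, 2], channels=[64, 128],
    self_attn_depths=[False, True],
)
x = torch.randn(1, 3, 32, 32)
sigma = torch.tensor([5.0])
out = model(x, sigma)  # conditions on c_noise = log(sigma) / 4 internally
print(out.shape)       # torch.Size([1, 3, 32, 32])
```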
comfy/k_diffusion/utils.py (modified)
...

@@ -10,25 +10,6 @@ from PIL import Image
import torch
from torch import nn, optim
from torch.utils import data
from torchvision.transforms import functional as TF


def from_pil_image(x):
    """Converts from a PIL image to a tensor."""
    x = TF.to_tensor(x)
    if x.ndim == 2:
        x = x[..., None]
    return x * 2 - 1


def to_pil_image(x):
    """Converts from a tensor to a PIL image."""
    if x.ndim == 4:
        assert x.shape[0] == 1
        x = x[0]
    if x.shape[0] == 1:
        x = x[0]
    return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)


def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):

...
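For reference, a round-trip sketch for the two PIL helpers shown in the hunk above (hypothetical; they fall in the region this commit trims, so this assumes a pre-commit checkout). They map between PIL images and float tensors in [-1, 1]:

```python
# Hypothetical round-trip, assuming from_pil_image/to_pil_image are in scope.
from PIL import Image

im = Image.new('RGB', (8, 8), (128, 64, 255))
t = from_pil_image(im)   # float tensor in [-1, 1], shape (3, 8, 8)
im2 = to_pil_image(t)    # clamps, rescales to [0, 1], converts back
print(t.min().item(), t.max().item(), im2.size)
```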