Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a7e8159d
Commit
a7e8159d
authored
Nov 08, 2022
by
Maruyama_Aya
Browse files
add ColoDiffusion codes: /ldm/module/, /ldm/data/, /scripts/test/
parent
441d584e
Changes
30
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
3271 additions
and
0 deletions
+3271
-0
examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py
.../images/diffusion/ldm/modules/image_degradation/bsrgan.py
+730
-0
examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py
...s/diffusion/ldm/modules/image_degradation/bsrgan_light.py
+650
-0
examples/images/diffusion/ldm/modules/image_degradation/utils/test.png
...es/diffusion/ldm/modules/image_degradation/utils/test.png
+0
-0
examples/images/diffusion/ldm/modules/image_degradation/utils_image.py
...es/diffusion/ldm/modules/image_degradation/utils_image.py
+916
-0
examples/images/diffusion/ldm/modules/losses/__init__.py
examples/images/diffusion/ldm/modules/losses/__init__.py
+1
-0
examples/images/diffusion/ldm/modules/losses/contperceptual.py
...les/images/diffusion/ldm/modules/losses/contperceptual.py
+111
-0
examples/images/diffusion/ldm/modules/losses/vqperceptual.py
examples/images/diffusion/ldm/modules/losses/vqperceptual.py
+167
-0
examples/images/diffusion/ldm/modules/x_transformer.py
examples/images/diffusion/ldm/modules/x_transformer.py
+641
-0
examples/images/diffusion/scripts/tests/test_checkpoint.py
examples/images/diffusion/scripts/tests/test_checkpoint.py
+37
-0
examples/images/diffusion/scripts/tests/test_watermark.py
examples/images/diffusion/scripts/tests/test_watermark.py
+18
-0
No files found.
examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py
0 → 100644
View file @
a7e8159d
This diff is collapsed.
Click to expand it.
examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py
0 → 100644
View file @
a7e8159d
This diff is collapsed.
Click to expand it.
examples/images/diffusion/ldm/modules/image_degradation/utils/test.png
0 → 100644
View file @
a7e8159d
431 KB
examples/images/diffusion/ldm/modules/image_degradation/utils_image.py
0 → 100644
View file @
a7e8159d
This diff is collapsed.
Click to expand it.
examples/images/diffusion/ldm/modules/losses/__init__.py
0 → 100644
View file @
a7e8159d
from
ldm.modules.losses.contperceptual
import
LPIPSWithDiscriminator
\ No newline at end of file
examples/images/diffusion/ldm/modules/losses/contperceptual.py
0 → 100644
View file @
a7e8159d
import
torch
import
torch.nn
as
nn
from
taming.modules.losses.vqperceptual
import
*
# TODO: taming dependency yes/no?
class LPIPSWithDiscriminator(nn.Module):
    """Autoencoder training loss for KL-regularized latent diffusion.

    Combines a pixel/LPIPS reconstruction NLL term (with a learnable global
    log-variance), a KL regularization term, and an adversarial PatchGAN term
    whose weight is balanced adaptively against the reconstruction gradients.
    `LPIPS`, `NLayerDiscriminator`, `weights_init`, `hinge_d_loss`,
    `vanilla_d_loss` and `adopt_weight` come from the `taming` star import.
    """

    def __init__(self,
                 disc_start,
                 logvar_init=0.0,
                 kl_weight=1.0,
                 pixelloss_weight=1.0,
                 disc_num_layers=3,
                 disc_in_channels=3,
                 disc_factor=1.0,
                 disc_weight=1.0,
                 perceptual_weight=1.0,
                 use_actnorm=False,
                 disc_conditional=False,
                 disc_loss="hinge"):
        """
        Args:
            disc_start: global step at which the GAN term kicks in.
            logvar_init: initial value of the learnable output log-variance.
            kl_weight: multiplier on the KL term.
            pixelloss_weight: stored as ``self.pixel_weight`` (not used below).
            disc_num_layers / disc_in_channels / use_actnorm: discriminator config.
            disc_factor: global multiplier on the adversarial term.
            disc_weight: multiplier on the adaptively computed d_weight.
            perceptual_weight: multiplier on the LPIPS term (0 disables it).
            disc_conditional: whether the discriminator also receives `cond`.
            disc_loss: "hinge" or "vanilla" discriminator loss.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        # Frozen perceptual metric; eval() keeps its norm layers fixed.
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # Learnable scalar output log-variance for the reconstruction NLL.
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the generator loss against the NLL loss.

        Returns a detached scalar proportional to the ratio of the gradient
        norms of the two losses at the decoder's last layer, clamped to
        [0, 1e4] and scaled by ``self.discriminator_weight``.
        """
        if last_layer is not None:
            layer = last_layer
        else:
            # NOTE(review): assumes something else sets `self.last_layer`
            # before this branch is reached -- it is never set in __init__;
            # confirm against callers.
            layer = self.last_layer[0]
        nll_grads = torch.autograd.grad(nll_loss, layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.discriminator_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        """Compute either the generator loss (optimizer_idx == 0) or the
        discriminator loss (optimizer_idx == 1), plus a logging dict.
        """
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        # Gaussian NLL with a learnable, globally shared log-variance.
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        # Normalize by batch size only (not by the number of elements).
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # autograd.grad fails outside training (no graph); use 0.
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            # GAN term is switched off before discriminator_iter_start.
            disc_factor = adopt_weight(self.disc_factor, global_step,
                                       threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step,
                                       threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean()
            }
            return d_loss, log
examples/images/diffusion/ldm/modules/losses/vqperceptual.py
0 → 100644
View file @
a7e8159d
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
from
einops
import
repeat
from
taming.modules.discriminator.model
import
NLayerDiscriminator
,
weights_init
from
taming.modules.losses.lpips
import
LPIPS
from
taming.modules.losses.vqperceptual
import
hinge_d_loss
,
vanilla_d_loss
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss with a per-example weight on each sample.

    `logits_real`/`logits_fake` are 4-D (batch first); `weights` has one
    entry per batch element. Returns a scalar tensor.
    """
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    # Hinge terms averaged per example over the non-batch dims.
    per_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
    per_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])
    # Weighted mean across the batch.
    total = weights.sum()
    weighted_real = (weights * per_real).sum() / total
    weighted_fake = (weights * per_fake).sum() / total
    return 0.5 * (weighted_real + weighted_fake)
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return `weight` once `global_step` has reached `threshold`, else `value`.

    Used to keep the adversarial term switched off early in training.
    """
    return value if global_step < threshold else weight
def measure_perplexity(predicted_indices, n_embed):
    """Return (perplexity, number of distinct codes used) for codebook indices.

    src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    When perplexity == num_embeddings, all clusters are used exactly equally.
    """
    one_hot = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = one_hot.mean(0)
    # exp of the entropy of the empirical code distribution (1e-10 avoids log 0).
    entropy = -(avg_probs * torch.log(avg_probs + 1e-10)).sum()
    perplexity = entropy.exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use
def l1(x, y):
    """Element-wise absolute error |x - y|."""
    diff = x - y
    return torch.abs(diff)
def l2(x, y):
    """Element-wise squared error (x - y)^2."""
    diff = x - y
    return diff ** 2
class VQLPIPSWithDiscriminator(nn.Module):
    """VQ autoencoder training loss: pixel + LPIPS reconstruction, codebook
    commitment term, and an adversarial PatchGAN term with adaptive weighting.

    `LPIPS`, `NLayerDiscriminator`, `weights_init`, `hinge_d_loss` and
    `vanilla_d_loss` come from the `taming` imports at the top of this file;
    `adopt_weight`, `measure_perplexity`, `l1` and `l2` are defined above.
    """

    def __init__(self,
                 disc_start,
                 codebook_weight=1.0,
                 pixelloss_weight=1.0,
                 disc_num_layers=3,
                 disc_in_channels=3,
                 disc_factor=1.0,
                 disc_weight=1.0,
                 perceptual_weight=1.0,
                 use_actnorm=False,
                 disc_conditional=False,
                 disc_ndf=64,
                 disc_loss="hinge",
                 n_classes=None,
                 perceptual_loss="lpips",
                 pixel_loss="l1"):
        """
        Args:
            disc_start: global step at which the GAN term kicks in.
            codebook_weight: multiplier on the codebook/commitment loss.
            pixelloss_weight: stored as ``self.pixel_weight`` (not used below).
            disc_num_layers / disc_in_channels / disc_ndf / use_actnorm:
                discriminator configuration.
            disc_factor: global multiplier on the adversarial term.
            disc_weight: multiplier on the adaptively computed d_weight.
            perceptual_weight: multiplier on the LPIPS term (0 disables it).
            disc_conditional: whether the discriminator also receives `cond`.
            disc_loss: "hinge" or "vanilla" discriminator loss.
            n_classes: codebook size, needed only for perplexity logging.
            perceptual_loss: only "lpips" is actually implemented.
            pixel_loss: "l1" or "l2" reconstruction distance.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        # "clips"/"dists" pass the assert but are rejected below; only
        # "lpips" is implemented.
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >>{perceptual_loss}<<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the generator loss against the NLL loss.

        Returns a detached scalar proportional to the ratio of the gradient
        norms of the two losses at the decoder's last layer, clamped to
        [0, 1e4] and scaled by ``self.discriminator_weight``.
        """
        if last_layer is not None:
            layer = last_layer
        else:
            # NOTE(review): assumes some caller/subclass sets `self.last_layer`
            # -- it is never set in __init__; confirm before relying on this.
            layer = self.last_layer[0]
        nll_grads = torch.autograd.grad(nll_loss, layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.discriminator_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                predicted_indices=None):
        """Compute either the generator loss (optimizer_idx == 0) or the
        discriminator loss (optimizer_idx == 1), plus a logging dict.
        """
        # Bugfix: the original called `exists(codebook_loss)`, but `exists`
        # is neither defined nor imported in this module, so any call raised
        # NameError. `exists(x)` is simply `x is not None`.
        if codebook_loss is None:
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                # autograd.grad fails outside training (no graph); use 0.
                assert not self.training
                d_weight = torch.tensor(0.0)

            # GAN term is switched off before discriminator_iter_start.
            disc_factor = adopt_weight(self.disc_factor, global_step,
                                       threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/p_loss".format(split): p_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }

            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step,
                                       threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean()
            }
            return d_loss, log
examples/images/diffusion/ldm/modules/x_transformer.py
0 → 100644
View file @
a7e8159d
This diff is collapsed.
Click to expand it.
examples/images/diffusion/scripts/tests/test_checkpoint.py
0 → 100644
View file @
a7e8159d
import
os
import
sys
from
copy
import
deepcopy
import
yaml
from
datetime
import
datetime
from
diffusers
import
StableDiffusionPipeline
import
torch
from
ldm.util
import
instantiate_from_config
from
main
import
get_parser
if __name__ == "__main__":
    # Smoke test: run the same random batch through the ColossalAI-configured
    # UNet and the diffusers reference UNet and compare output shapes.
    with torch.no_grad():
        # Build the UNet described by the training YAML config.
        yaml_path = "../../train_colossalai.yaml"
        with open(yaml_path, 'r', encoding='utf-8') as f:
            config = f.read()
        base_config = yaml.load(config, Loader=yaml.FullLoader)
        unet_config = base_config['model']['params']['unet_config']
        diffusion_model = instantiate_from_config(unet_config).to("cuda:0")

        # Reference UNet from a local stable-diffusion-v1-4 checkpoint.
        pipe = StableDiffusionPipeline.from_pretrained(
            "/data/scratch/diffuser/stable-diffusion-v1-4").to("cuda:0")
        dif_model_2 = pipe.unet

        # Identical inputs for both models: latent batch, timesteps, context.
        random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0")
        random_input_2 = torch.clone(random_input_).to("cuda:0")
        time_stamp = torch.randint(20, (4,)).to("cuda:0")
        time_stamp2 = torch.clone(time_stamp).to("cuda:0")
        context_ = torch.rand((4, 77, 768)).to("cuda:0")
        context_2 = torch.clone(context_).to("cuda:0")

        out_1 = diffusion_model(random_input_, time_stamp, context_)
        out_2 = dif_model_2(random_input_2, time_stamp2, context_2)
        print(out_1.shape)
        # diffusers returns a dict-like output keyed by 'sample'.
        print(out_2['sample'].shape)
\ No newline at end of file
examples/images/diffusion/scripts/tests/test_watermark.py
0 → 100644
View file @
a7e8159d
import
cv2
import
fire
from
imwatermark
import
WatermarkDecoder
def testit(img_path):
    """Read an image from `img_path` and print its decoded invisible
    watermark, or "null" when the payload is not valid UTF-8.
    """
    bgr = cv2.imread(img_path)
    decoder = WatermarkDecoder('bytes', 136)
    watermark = decoder.decode(bgr, 'dwtDct')
    try:
        dec = watermark.decode('utf-8')
    except Exception:
        # Bugfix: was a bare `except:`, which also swallows SystemExit and
        # KeyboardInterrupt. Keep the best-effort fallback, but only for
        # ordinary errors (bad UTF-8 payload, unexpected decoder output).
        dec = "null"
    print(dec)
if __name__ == "__main__":
    # Expose testit() as a CLI: `python test_watermark.py <img_path>`.
    fire.Fire(testit)
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment