ModelZoo / stable-diffusion_pytorch - Commits

Commit 86685e45
Authored Aug 22, 2023 by lijian6
Initial commit
Pipeline #524 canceled with stages
Changes: 137 | Pipelines: 1
Showing 20 changed files with 2720 additions and 0 deletions (+2720, -0)
ldm/modules/losses/contperceptual.py                 +111  -0
ldm/modules/losses/vqperceptual.py                   +167  -0
ldm/modules/x_transformer.py                         +641  -0
ldm/util.py                                          +203  -0
main.py                                              +741  -0
models/first_stage_models/kl-f16/config.yaml          +44  -0
models/first_stage_models/kl-f32/config.yaml          +46  -0
models/first_stage_models/kl-f4/config.yaml           +41  -0
models/first_stage_models/kl-f8/config.yaml           +42  -0
models/first_stage_models/vq-f16/config.yaml          +49  -0
models/first_stage_models/vq-f4-noattn/config.yaml    +46  -0
models/first_stage_models/vq-f4/config.yaml           +45  -0
models/first_stage_models/vq-f8-n256/config.yaml      +48  -0
models/first_stage_models/vq-f8/config.yaml           +48  -0
models/ldm/bsr_sr/config.yaml                         +80  -0
models/ldm/celeba256/config.yaml                      +70  -0
models/ldm/cin256/config.yaml                         +80  -0
models/ldm/ffhq256/config.yaml                        +70  -0
models/ldm/inpainting_big/config.yaml                 +67  -0
models/ldm/layout2img-openimages256/config.yaml       +81  -0
ldm/modules/losses/contperceptual.py  (new file, mode 100644)
import torch
import torch.nn as nn

from taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?


class LPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", weights=None):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/logvar".format(split): self.logvar.detach(),
                   "{}/kl_loss".format(split): kl_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
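The loss above is written for a two-optimizer setup: optimizer_idx 0 produces the autoencoder ("generator") loss with reconstruction, KL, and an adaptively weighted GAN term, while optimizer_idx 1 produces the discriminator loss. A minimal sketch of how a Lightning-style training step might call it is shown below; `model`, its encode/decode methods, and `get_last_layer()` are hypothetical stand-ins for an AutoencoderKL-style module and are not part of this commit, only the loss call mirrors the file above.

# Hedged sketch only: the model API here is assumed, not defined in this commit.
def training_step_sketch(model, loss_fn, batch, optimizer_idx, global_step):
    inputs = batch["image"]                      # (B, C, H, W) in [-1, 1]
    posterior = model.encode(inputs)             # diagonal Gaussian posterior with a .kl() method
    reconstructions = model.decode(posterior.sample())

    loss, log = loss_fn(
        inputs, reconstructions, posterior,
        optimizer_idx=optimizer_idx,             # 0: generator pass, 1: discriminator pass
        global_step=global_step,
        last_layer=model.get_last_layer(),       # used for the adaptive GAN weight
        split="train",
    )
    return loss, log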
ldm/modules/losses/vqperceptual.py  (new file, mode 100644)
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss


def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    loss_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
    loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])
    loss_real = (weights * loss_real).sum() / weights.sum()
    loss_fake = (weights * loss_fake).sum() / weights.sum()
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


def measure_perplexity(predicted_indices, n_embed):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use


def l1(x, y):
    return torch.abs(x - y)


def l2(x, y):
    return torch.pow((x - y), 2)


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >>{perceptual_loss}<<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        if not exists(codebook_loss):
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
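Two helpers in this file do most of the bookkeeping: adopt_weight zeroes the GAN factor until global_step reaches discriminator_iter_start, and measure_perplexity reports how evenly the VQ codebook is used (perplexity approaches n_embed when all codes are hit equally often). A small standalone check, assuming the package is importable under the path shown and using purely illustrative values:

import torch
from ldm.modules.losses.vqperceptual import adopt_weight, measure_perplexity

# adopt_weight: the GAN factor stays at `value` (0. by default) until the threshold step is reached
assert adopt_weight(1.0, global_step=100, threshold=500) == 0.0
assert adopt_weight(1.0, global_step=600, threshold=500) == 1.0

# measure_perplexity: with all 4 codes used equally often, perplexity approaches n_embed
idx = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])
perplexity, cluster_use = measure_perplexity(idx, 4)
# perplexity is approximately 4.0, cluster_use == 4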
ldm/modules/x_transformer.py  (new file, mode 100644)
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import
torch
from
torch
import
nn
,
einsum
import
torch.nn.functional
as
F
from
functools
import
partial
from
inspect
import
isfunction
from
collections
import
namedtuple
from
einops
import
rearrange
,
repeat
,
reduce
# constants
DEFAULT_DIM_HEAD
=
64
Intermediates
=
namedtuple
(
'Intermediates'
,
[
'pre_softmax_attn'
,
'post_softmax_attn'
])
LayerIntermediates
=
namedtuple
(
'Intermediates'
,
[
'hiddens'
,
'attn_intermediates'
])
class
AbsolutePositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
max_seq_len
):
super
().
__init__
()
self
.
emb
=
nn
.
Embedding
(
max_seq_len
,
dim
)
self
.
init_
()
def
init_
(
self
):
nn
.
init
.
normal_
(
self
.
emb
.
weight
,
std
=
0.02
)
def
forward
(
self
,
x
):
n
=
torch
.
arange
(
x
.
shape
[
1
],
device
=
x
.
device
)
return
self
.
emb
(
n
)[
None
,
:,
:]
class
FixedPositionalEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
inv_freq
=
1.
/
(
10000
**
(
torch
.
arange
(
0
,
dim
,
2
).
float
()
/
dim
))
self
.
register_buffer
(
'inv_freq'
,
inv_freq
)
def
forward
(
self
,
x
,
seq_dim
=
1
,
offset
=
0
):
t
=
torch
.
arange
(
x
.
shape
[
seq_dim
],
device
=
x
.
device
).
type_as
(
self
.
inv_freq
)
+
offset
sinusoid_inp
=
torch
.
einsum
(
'i , j -> i j'
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
sinusoid_inp
.
sin
(),
sinusoid_inp
.
cos
()),
dim
=-
1
)
return
emb
[
None
,
:,
:]
# helpers
def
exists
(
val
):
return
val
is
not
None
def
default
(
val
,
d
):
if
exists
(
val
):
return
val
return
d
()
if
isfunction
(
d
)
else
d
def
always
(
val
):
def
inner
(
*
args
,
**
kwargs
):
return
val
return
inner
def
not_equals
(
val
):
def
inner
(
x
):
return
x
!=
val
return
inner
def
equals
(
val
):
def
inner
(
x
):
return
x
==
val
return
inner
def
max_neg_value
(
tensor
):
return
-
torch
.
finfo
(
tensor
.
dtype
).
max
# keyword argument helpers
def
pick_and_pop
(
keys
,
d
):
values
=
list
(
map
(
lambda
key
:
d
.
pop
(
key
),
keys
))
return
dict
(
zip
(
keys
,
values
))
def
group_dict_by_key
(
cond
,
d
):
return_val
=
[
dict
(),
dict
()]
for
key
in
d
.
keys
():
match
=
bool
(
cond
(
key
))
ind
=
int
(
not
match
)
return_val
[
ind
][
key
]
=
d
[
key
]
return
(
*
return_val
,)
def
string_begins_with
(
prefix
,
str
):
return
str
.
startswith
(
prefix
)
def
group_by_key_prefix
(
prefix
,
d
):
return
group_dict_by_key
(
partial
(
string_begins_with
,
prefix
),
d
)
def
groupby_prefix_and_trim
(
prefix
,
d
):
kwargs_with_prefix
,
kwargs
=
group_dict_by_key
(
partial
(
string_begins_with
,
prefix
),
d
)
kwargs_without_prefix
=
dict
(
map
(
lambda
x
:
(
x
[
0
][
len
(
prefix
):],
x
[
1
]),
tuple
(
kwargs_with_prefix
.
items
())))
return
kwargs_without_prefix
,
kwargs
# classes
class
Scale
(
nn
.
Module
):
def
__init__
(
self
,
value
,
fn
):
super
().
__init__
()
self
.
value
=
value
self
.
fn
=
fn
def
forward
(
self
,
x
,
**
kwargs
):
x
,
*
rest
=
self
.
fn
(
x
,
**
kwargs
)
return
(
x
*
self
.
value
,
*
rest
)
class
Rezero
(
nn
.
Module
):
def
__init__
(
self
,
fn
):
super
().
__init__
()
self
.
fn
=
fn
self
.
g
=
nn
.
Parameter
(
torch
.
zeros
(
1
))
def
forward
(
self
,
x
,
**
kwargs
):
x
,
*
rest
=
self
.
fn
(
x
,
**
kwargs
)
return
(
x
*
self
.
g
,
*
rest
)
class
ScaleNorm
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
eps
=
1e-5
):
super
().
__init__
()
self
.
scale
=
dim
**
-
0.5
self
.
eps
=
eps
self
.
g
=
nn
.
Parameter
(
torch
.
ones
(
1
))
def
forward
(
self
,
x
):
norm
=
torch
.
norm
(
x
,
dim
=-
1
,
keepdim
=
True
)
*
self
.
scale
return
x
/
norm
.
clamp
(
min
=
self
.
eps
)
*
self
.
g
class
RMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
eps
=
1e-8
):
super
().
__init__
()
self
.
scale
=
dim
**
-
0.5
self
.
eps
=
eps
self
.
g
=
nn
.
Parameter
(
torch
.
ones
(
dim
))
def
forward
(
self
,
x
):
norm
=
torch
.
norm
(
x
,
dim
=-
1
,
keepdim
=
True
)
*
self
.
scale
return
x
/
norm
.
clamp
(
min
=
self
.
eps
)
*
self
.
g
class
Residual
(
nn
.
Module
):
def
forward
(
self
,
x
,
residual
):
return
x
+
residual
class
GRUGating
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
self
.
gru
=
nn
.
GRUCell
(
dim
,
dim
)
def
forward
(
self
,
x
,
residual
):
gated_output
=
self
.
gru
(
rearrange
(
x
,
'b n d -> (b n) d'
),
rearrange
(
residual
,
'b n d -> (b n) d'
)
)
return
gated_output
.
reshape_as
(
x
)
# feedforward
class
GEGLU
(
nn
.
Module
):
def
__init__
(
self
,
dim_in
,
dim_out
):
super
().
__init__
()
self
.
proj
=
nn
.
Linear
(
dim_in
,
dim_out
*
2
)
def
forward
(
self
,
x
):
x
,
gate
=
self
.
proj
(
x
).
chunk
(
2
,
dim
=-
1
)
return
x
*
F
.
gelu
(
gate
)
class
FeedForward
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_out
=
None
,
mult
=
4
,
glu
=
False
,
dropout
=
0.
):
super
().
__init__
()
inner_dim
=
int
(
dim
*
mult
)
dim_out
=
default
(
dim_out
,
dim
)
project_in
=
nn
.
Sequential
(
nn
.
Linear
(
dim
,
inner_dim
),
nn
.
GELU
()
)
if
not
glu
else
GEGLU
(
dim
,
inner_dim
)
self
.
net
=
nn
.
Sequential
(
project_in
,
nn
.
Dropout
(
dropout
),
nn
.
Linear
(
inner_dim
,
dim_out
)
)
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
# attention.
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
dim_head
=
DEFAULT_DIM_HEAD
,
heads
=
8
,
causal
=
False
,
mask
=
None
,
talking_heads
=
False
,
sparse_topk
=
None
,
use_entmax15
=
False
,
num_mem_kv
=
0
,
dropout
=
0.
,
on_attn
=
False
):
super
().
__init__
()
if
use_entmax15
:
raise
NotImplementedError
(
"Check out entmax activation instead of softmax activation!"
)
self
.
scale
=
dim_head
**
-
0.5
self
.
heads
=
heads
self
.
causal
=
causal
self
.
mask
=
mask
inner_dim
=
dim_head
*
heads
self
.
to_q
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_k
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
to_v
=
nn
.
Linear
(
dim
,
inner_dim
,
bias
=
False
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
# talking heads
self
.
talking_heads
=
talking_heads
if
talking_heads
:
self
.
pre_softmax_proj
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
heads
))
self
.
post_softmax_proj
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
heads
))
# explicit topk sparse attention
self
.
sparse_topk
=
sparse_topk
# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
self
.
attn_fn
=
F
.
softmax
# add memory key / values
self
.
num_mem_kv
=
num_mem_kv
if
num_mem_kv
>
0
:
self
.
mem_k
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
num_mem_kv
,
dim_head
))
self
.
mem_v
=
nn
.
Parameter
(
torch
.
randn
(
heads
,
num_mem_kv
,
dim_head
))
# attention on attention
self
.
attn_on_attn
=
on_attn
self
.
to_out
=
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
dim
*
2
),
nn
.
GLU
())
if
on_attn
else
nn
.
Linear
(
inner_dim
,
dim
)
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
,
context_mask
=
None
,
rel_pos
=
None
,
sinusoidal_emb
=
None
,
prev_attn
=
None
,
mem
=
None
):
b
,
n
,
_
,
h
,
talking_heads
,
device
=
*
x
.
shape
,
self
.
heads
,
self
.
talking_heads
,
x
.
device
kv_input
=
default
(
context
,
x
)
q_input
=
x
k_input
=
kv_input
v_input
=
kv_input
if
exists
(
mem
):
k_input
=
torch
.
cat
((
mem
,
k_input
),
dim
=-
2
)
v_input
=
torch
.
cat
((
mem
,
v_input
),
dim
=-
2
)
if
exists
(
sinusoidal_emb
):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset
=
k_input
.
shape
[
-
2
]
-
q_input
.
shape
[
-
2
]
q_input
=
q_input
+
sinusoidal_emb
(
q_input
,
offset
=
offset
)
k_input
=
k_input
+
sinusoidal_emb
(
k_input
)
q
=
self
.
to_q
(
q_input
)
k
=
self
.
to_k
(
k_input
)
v
=
self
.
to_v
(
v_input
)
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
'b n (h d) -> b h n d'
,
h
=
h
),
(
q
,
k
,
v
))
input_mask
=
None
if
any
(
map
(
exists
,
(
mask
,
context_mask
))):
q_mask
=
default
(
mask
,
lambda
:
torch
.
ones
((
b
,
n
),
device
=
device
).
bool
())
k_mask
=
q_mask
if
not
exists
(
context
)
else
context_mask
k_mask
=
default
(
k_mask
,
lambda
:
torch
.
ones
((
b
,
k
.
shape
[
-
2
]),
device
=
device
).
bool
())
q_mask
=
rearrange
(
q_mask
,
'b i -> b () i ()'
)
k_mask
=
rearrange
(
k_mask
,
'b j -> b () () j'
)
input_mask
=
q_mask
*
k_mask
if
self
.
num_mem_kv
>
0
:
mem_k
,
mem_v
=
map
(
lambda
t
:
repeat
(
t
,
'h n d -> b h n d'
,
b
=
b
),
(
self
.
mem_k
,
self
.
mem_v
))
k
=
torch
.
cat
((
mem_k
,
k
),
dim
=-
2
)
v
=
torch
.
cat
((
mem_v
,
v
),
dim
=-
2
)
if
exists
(
input_mask
):
input_mask
=
F
.
pad
(
input_mask
,
(
self
.
num_mem_kv
,
0
),
value
=
True
)
dots
=
einsum
(
'b h i d, b h j d -> b h i j'
,
q
,
k
)
*
self
.
scale
mask_value
=
max_neg_value
(
dots
)
if
exists
(
prev_attn
):
dots
=
dots
+
prev_attn
pre_softmax_attn
=
dots
if
talking_heads
:
dots
=
einsum
(
'b h i j, h k -> b k i j'
,
dots
,
self
.
pre_softmax_proj
).
contiguous
()
if
exists
(
rel_pos
):
dots
=
rel_pos
(
dots
)
if
exists
(
input_mask
):
dots
.
masked_fill_
(
~
input_mask
,
mask_value
)
del
input_mask
if
self
.
causal
:
i
,
j
=
dots
.
shape
[
-
2
:]
r
=
torch
.
arange
(
i
,
device
=
device
)
mask
=
rearrange
(
r
,
'i -> () () i ()'
)
<
rearrange
(
r
,
'j -> () () () j'
)
mask
=
F
.
pad
(
mask
,
(
j
-
i
,
0
),
value
=
False
)
dots
.
masked_fill_
(
mask
,
mask_value
)
del
mask
if
exists
(
self
.
sparse_topk
)
and
self
.
sparse_topk
<
dots
.
shape
[
-
1
]:
top
,
_
=
dots
.
topk
(
self
.
sparse_topk
,
dim
=-
1
)
vk
=
top
[...,
-
1
].
unsqueeze
(
-
1
).
expand_as
(
dots
)
mask
=
dots
<
vk
dots
.
masked_fill_
(
mask
,
mask_value
)
del
mask
attn
=
self
.
attn_fn
(
dots
,
dim
=-
1
)
post_softmax_attn
=
attn
attn
=
self
.
dropout
(
attn
)
if
talking_heads
:
attn
=
einsum
(
'b h i j, h k -> b k i j'
,
attn
,
self
.
post_softmax_proj
).
contiguous
()
out
=
einsum
(
'b h i j, b h j d -> b h i d'
,
attn
,
v
)
out
=
rearrange
(
out
,
'b h n d -> b n (h d)'
)
intermediates
=
Intermediates
(
pre_softmax_attn
=
pre_softmax_attn
,
post_softmax_attn
=
post_softmax_attn
)
return
self
.
to_out
(
out
),
intermediates
class
AttentionLayers
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
depth
,
heads
=
8
,
causal
=
False
,
cross_attend
=
False
,
only_cross
=
False
,
use_scalenorm
=
False
,
use_rmsnorm
=
False
,
use_rezero
=
False
,
rel_pos_num_buckets
=
32
,
rel_pos_max_distance
=
128
,
position_infused_attn
=
False
,
custom_layers
=
None
,
sandwich_coef
=
None
,
par_ratio
=
None
,
residual_attn
=
False
,
cross_residual_attn
=
False
,
macaron
=
False
,
pre_norm
=
True
,
gate_residual
=
False
,
**
kwargs
):
super
().
__init__
()
ff_kwargs
,
kwargs
=
groupby_prefix_and_trim
(
'ff_'
,
kwargs
)
attn_kwargs
,
_
=
groupby_prefix_and_trim
(
'attn_'
,
kwargs
)
dim_head
=
attn_kwargs
.
get
(
'dim_head'
,
DEFAULT_DIM_HEAD
)
self
.
dim
=
dim
self
.
depth
=
depth
self
.
layers
=
nn
.
ModuleList
([])
self
.
has_pos_emb
=
position_infused_attn
self
.
pia_pos_emb
=
FixedPositionalEmbedding
(
dim
)
if
position_infused_attn
else
None
self
.
rotary_pos_emb
=
always
(
None
)
assert
rel_pos_num_buckets
<=
rel_pos_max_distance
,
'number of relative position buckets must be less than the relative position max distance'
self
.
rel_pos
=
None
self
.
pre_norm
=
pre_norm
self
.
residual_attn
=
residual_attn
self
.
cross_residual_attn
=
cross_residual_attn
norm_class
=
ScaleNorm
if
use_scalenorm
else
nn
.
LayerNorm
norm_class
=
RMSNorm
if
use_rmsnorm
else
norm_class
norm_fn
=
partial
(
norm_class
,
dim
)
norm_fn
=
nn
.
Identity
if
use_rezero
else
norm_fn
branch_fn
=
Rezero
if
use_rezero
else
None
if
cross_attend
and
not
only_cross
:
default_block
=
(
'a'
,
'c'
,
'f'
)
elif
cross_attend
and
only_cross
:
default_block
=
(
'c'
,
'f'
)
else
:
default_block
=
(
'a'
,
'f'
)
if
macaron
:
default_block
=
(
'f'
,)
+
default_block
if
exists
(
custom_layers
):
layer_types
=
custom_layers
elif
exists
(
par_ratio
):
par_depth
=
depth
*
len
(
default_block
)
assert
1
<
par_ratio
<=
par_depth
,
'par ratio out of range'
default_block
=
tuple
(
filter
(
not_equals
(
'f'
),
default_block
))
par_attn
=
par_depth
//
par_ratio
depth_cut
=
par_depth
*
2
//
3
# 2 / 3 attention layer cutoff suggested by PAR paper
par_width
=
(
depth_cut
+
depth_cut
//
par_attn
)
//
par_attn
assert
len
(
default_block
)
<=
par_width
,
'default block is too large for par_ratio'
par_block
=
default_block
+
(
'f'
,)
*
(
par_width
-
len
(
default_block
))
par_head
=
par_block
*
par_attn
layer_types
=
par_head
+
(
'f'
,)
*
(
par_depth
-
len
(
par_head
))
elif
exists
(
sandwich_coef
):
assert
sandwich_coef
>
0
and
sandwich_coef
<=
depth
,
'sandwich coefficient should be less than the depth'
layer_types
=
(
'a'
,)
*
sandwich_coef
+
default_block
*
(
depth
-
sandwich_coef
)
+
(
'f'
,)
*
sandwich_coef
else
:
layer_types
=
default_block
*
depth
self
.
layer_types
=
layer_types
self
.
num_attn_layers
=
len
(
list
(
filter
(
equals
(
'a'
),
layer_types
)))
for
layer_type
in
self
.
layer_types
:
if
layer_type
==
'a'
:
layer
=
Attention
(
dim
,
heads
=
heads
,
causal
=
causal
,
**
attn_kwargs
)
elif
layer_type
==
'c'
:
layer
=
Attention
(
dim
,
heads
=
heads
,
**
attn_kwargs
)
elif
layer_type
==
'f'
:
layer
=
FeedForward
(
dim
,
**
ff_kwargs
)
layer
=
layer
if
not
macaron
else
Scale
(
0.5
,
layer
)
else
:
raise
Exception
(
f
'invalid layer type
{
layer_type
}
'
)
if
isinstance
(
layer
,
Attention
)
and
exists
(
branch_fn
):
layer
=
branch_fn
(
layer
)
if
gate_residual
:
residual_fn
=
GRUGating
(
dim
)
else
:
residual_fn
=
Residual
()
self
.
layers
.
append
(
nn
.
ModuleList
([
norm_fn
(),
layer
,
residual_fn
]))
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
,
context_mask
=
None
,
mems
=
None
,
return_hiddens
=
False
):
hiddens
=
[]
intermediates
=
[]
prev_attn
=
None
prev_cross_attn
=
None
mems
=
mems
.
copy
()
if
exists
(
mems
)
else
[
None
]
*
self
.
num_attn_layers
for
ind
,
(
layer_type
,
(
norm
,
block
,
residual_fn
))
in
enumerate
(
zip
(
self
.
layer_types
,
self
.
layers
)):
is_last
=
ind
==
(
len
(
self
.
layers
)
-
1
)
if
layer_type
==
'a'
:
hiddens
.
append
(
x
)
layer_mem
=
mems
.
pop
(
0
)
residual
=
x
if
self
.
pre_norm
:
x
=
norm
(
x
)
if
layer_type
==
'a'
:
out
,
inter
=
block
(
x
,
mask
=
mask
,
sinusoidal_emb
=
self
.
pia_pos_emb
,
rel_pos
=
self
.
rel_pos
,
prev_attn
=
prev_attn
,
mem
=
layer_mem
)
elif
layer_type
==
'c'
:
out
,
inter
=
block
(
x
,
context
=
context
,
mask
=
mask
,
context_mask
=
context_mask
,
prev_attn
=
prev_cross_attn
)
elif
layer_type
==
'f'
:
out
=
block
(
x
)
x
=
residual_fn
(
out
,
residual
)
if
layer_type
in
(
'a'
,
'c'
):
intermediates
.
append
(
inter
)
if
layer_type
==
'a'
and
self
.
residual_attn
:
prev_attn
=
inter
.
pre_softmax_attn
elif
layer_type
==
'c'
and
self
.
cross_residual_attn
:
prev_cross_attn
=
inter
.
pre_softmax_attn
if
not
self
.
pre_norm
and
not
is_last
:
x
=
norm
(
x
)
if
return_hiddens
:
intermediates
=
LayerIntermediates
(
hiddens
=
hiddens
,
attn_intermediates
=
intermediates
)
return
x
,
intermediates
return
x
class
Encoder
(
AttentionLayers
):
def
__init__
(
self
,
**
kwargs
):
assert
'causal'
not
in
kwargs
,
'cannot set causality on encoder'
super
().
__init__
(
causal
=
False
,
**
kwargs
)
class
TransformerWrapper
(
nn
.
Module
):
def
__init__
(
self
,
*
,
num_tokens
,
max_seq_len
,
attn_layers
,
emb_dim
=
None
,
max_mem_len
=
0.
,
emb_dropout
=
0.
,
num_memory_tokens
=
None
,
tie_embedding
=
False
,
use_pos_emb
=
True
):
super
().
__init__
()
assert
isinstance
(
attn_layers
,
AttentionLayers
),
'attention layers must be one of Encoder or Decoder'
dim
=
attn_layers
.
dim
emb_dim
=
default
(
emb_dim
,
dim
)
self
.
max_seq_len
=
max_seq_len
self
.
max_mem_len
=
max_mem_len
self
.
num_tokens
=
num_tokens
self
.
token_emb
=
nn
.
Embedding
(
num_tokens
,
emb_dim
)
self
.
pos_emb
=
AbsolutePositionalEmbedding
(
emb_dim
,
max_seq_len
)
if
(
use_pos_emb
and
not
attn_layers
.
has_pos_emb
)
else
always
(
0
)
self
.
emb_dropout
=
nn
.
Dropout
(
emb_dropout
)
self
.
project_emb
=
nn
.
Linear
(
emb_dim
,
dim
)
if
emb_dim
!=
dim
else
nn
.
Identity
()
self
.
attn_layers
=
attn_layers
self
.
norm
=
nn
.
LayerNorm
(
dim
)
self
.
init_
()
self
.
to_logits
=
nn
.
Linear
(
dim
,
num_tokens
)
if
not
tie_embedding
else
lambda
t
:
t
@
self
.
token_emb
.
weight
.
t
()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens
=
default
(
num_memory_tokens
,
0
)
self
.
num_memory_tokens
=
num_memory_tokens
if
num_memory_tokens
>
0
:
self
.
memory_tokens
=
nn
.
Parameter
(
torch
.
randn
(
num_memory_tokens
,
dim
))
# let funnel encoder know number of memory tokens, if specified
if
hasattr
(
attn_layers
,
'num_memory_tokens'
):
attn_layers
.
num_memory_tokens
=
num_memory_tokens
def
init_
(
self
):
nn
.
init
.
normal_
(
self
.
token_emb
.
weight
,
std
=
0.02
)
def
forward
(
self
,
x
,
return_embeddings
=
False
,
mask
=
None
,
return_mems
=
False
,
return_attn
=
False
,
mems
=
None
,
**
kwargs
):
b
,
n
,
device
,
num_mem
=
*
x
.
shape
,
x
.
device
,
self
.
num_memory_tokens
x
=
self
.
token_emb
(
x
)
x
+=
self
.
pos_emb
(
x
)
x
=
self
.
emb_dropout
(
x
)
x
=
self
.
project_emb
(
x
)
if
num_mem
>
0
:
mem
=
repeat
(
self
.
memory_tokens
,
'n d -> b n d'
,
b
=
b
)
x
=
torch
.
cat
((
mem
,
x
),
dim
=
1
)
# auto-handle masking after appending memory tokens
if
exists
(
mask
):
mask
=
F
.
pad
(
mask
,
(
num_mem
,
0
),
value
=
True
)
x
,
intermediates
=
self
.
attn_layers
(
x
,
mask
=
mask
,
mems
=
mems
,
return_hiddens
=
True
,
**
kwargs
)
x
=
self
.
norm
(
x
)
mem
,
x
=
x
[:,
:
num_mem
],
x
[:,
num_mem
:]
out
=
self
.
to_logits
(
x
)
if
not
return_embeddings
else
x
if
return_mems
:
hiddens
=
intermediates
.
hiddens
new_mems
=
list
(
map
(
lambda
pair
:
torch
.
cat
(
pair
,
dim
=-
2
),
zip
(
mems
,
hiddens
)))
if
exists
(
mems
)
else
hiddens
new_mems
=
list
(
map
(
lambda
t
:
t
[...,
-
self
.
max_mem_len
:,
:].
detach
(),
new_mems
))
return
out
,
new_mems
if
return_attn
:
attn_maps
=
list
(
map
(
lambda
t
:
t
.
post_softmax_attn
,
intermediates
.
attn_intermediates
))
return
out
,
attn_maps
return
out
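TransformerWrapper expects an AttentionLayers instance (here the non-causal Encoder) and returns token logits by default, or the final hidden states when called with return_embeddings=True. A minimal sketch using the classes above; the vocabulary size, depth, and other hyperparameters are arbitrary illustrative values, not settings from this commit:

import torch

# Illustrative sizes only; none of these hyperparameters come from the commit.
model = TransformerWrapper(
    num_tokens=30522,                                 # vocabulary size
    max_seq_len=77,
    attn_layers=Encoder(dim=512, depth=6, heads=8),   # non-causal transformer stack
)

tokens = torch.randint(0, 30522, (2, 77))             # (batch, sequence) of token ids
hidden = model(tokens, return_embeddings=True)        # -> (2, 77, 512) hidden states
logits = model(tokens)                                # -> (2, 77, 30522) token logits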
ldm/util.py  (new file, mode 100644)
import importlib

import torch
import numpy as np
from collections import abc
from einops import rearrange
from functools import partial

import multiprocessing as mp
from threading import Thread
from queue import Queue

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont


def log_txt_as_img(wh, xc, size=10):
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
        nc = int(40 * (wh[0] / 256))
        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts


def ismap(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    # create dummy dataset instance

    # run prefetching
    if idx_to_fn:
        res = func(data, worker_id=idx)
    else:
        res = func(data)
    Q.put([idx, res])
    Q.put("Done")


def parallel_data_prefetch(
        func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    # if target_data_type not in ["ndarray", "list"]:
    #     raise ValueError(
    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
    #     )
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread
    # spawn processes
    if target_data_type == "ndarray":
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(np.array_split(data, n_proc))
        ]
    else:
        step = (
            int(len(data) / n_proc + 1)
            if len(data) % n_proc != 0
            else int(len(data) / n_proc)
        )
        arguments = [
            [func, Q, part, i, use_worker_id]
            for i, part in enumerate(
                [data[i: i + step] for i in range(0, len(data), step)]
            )
        ]
    processes = []
    for i in range(n_proc):
        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
        processes += [p]

    # start processes
    print(f"Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_proc)]
    try:
        for p in processes:
            p.start()

        k = 0
        while k < n_proc:
            # get result
            res = Q.get()
            if res == "Done":
                k += 1
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in processes:
            p.terminate()

        raise e
    finally:
        for p in processes:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res
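instantiate_from_config resolves the `target` import path and passes `params` as keyword arguments; this is how main.py below builds the model, data module, loggers, and callbacks from the YAML configs in this commit. A small sketch, using torch.nn.Linear purely as a convenient importable class for illustration:

from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config

# Any importable class works as `target`; torch.nn.Linear is used only for illustration.
cfg = OmegaConf.create({
    "target": "torch.nn.Linear",
    "params": {"in_features": 16, "out_features": 4},
})
layer = instantiate_from_config(cfg)   # equivalent to torch.nn.Linear(in_features=16, out_features=4)
print(type(layer).__name__)            # Linear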
main.py  (new file, mode 100644)
import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np
import time
import torch
import torchvision
import pytorch_lightning as pl

from packaging import version
from omegaconf import OmegaConf
from torch.utils.data import random_split, DataLoader, Dataset, Subset
from functools import partial
from PIL import Image

from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info

from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config


def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("-n", "--name", type=str, const=True, default="", nargs="?",
                        help="postfix for logdir")
    parser.add_argument("-r", "--resume", type=str, const=True, default="", nargs="?",
                        help="resume from logdir or checkpoint in logdir")
    parser.add_argument("-b", "--base", nargs="*", metavar="base_config.yaml", default=list(),
                        help="paths to base configs. Loaded from left-to-right. "
                             "Parameters can be overwritten or added with command-line options of the form `--key value`.")
    parser.add_argument("-t", "--train", type=str2bool, const=True, default=False, nargs="?",
                        help="train")
    parser.add_argument("--no-test", type=str2bool, const=True, default=False, nargs="?",
                        help="disable test")
    parser.add_argument("-p", "--project", help="name of new or path to existing project")
    parser.add_argument("-d", "--debug", type=str2bool, nargs="?", const=True, default=False,
                        help="enable post-mortem debugging")
    parser.add_argument("-s", "--seed", type=int, default=23,
                        help="seed for seed_everything")
    parser.add_argument("-f", "--postfix", type=str, default="",
                        help="post-postfix for default name")
    parser.add_argument("-l", "--logdir", type=str, default="logs",
                        help="directory for logging dat shit")
    parser.add_argument("--scale_lr", type=str2bool, nargs="?", const=True, default=True,
                        help="scale base-lr by ngpu * batch_size * n_accumulate")
    return parser


def nondefault_trainer_args(opt):
    parser = argparse.ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args([])
    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        split_size = dataset.num_records // worker_info.num_workers
        # reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        return np.random.seed(np.random.get_state()[1][0] + worker_id)


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = dict(
            (k, instantiate_from_config(self.dataset_configs[k]))
            for k in self.dataset_configs)
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["train"], batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
                          worker_init_fn=init_fn)

    def _val_dataloader(self, shuffle=False):
        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["validation"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(self.datasets["test"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
                          num_workers=self.num_workers, worker_init_fn=init_fn)


class SetupCallback(Callback):
    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.TestTubeLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
                check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
        epoch_time = time.time() - self.start_time

        try:
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass


if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: main.DataModuleFromConfig
    #   params:
    #      batch_size: int
    #      wrap: bool
    #      train:
    #          target: path to train dataset
    #          params:
    #              key: value
    #      validation:
    #          target: path to validation dataset
    #          params:
    #              key: value
    #      test:
    #          target: path to test dataset
    #          params:
    #              key: value
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python main.py`
    # (in particular `main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both."
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        trainer_config["accelerator"] = "ddp"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if not "gpus" in trainer_config:
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                }
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TestTubeLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                }
            },
        }
        default_logger_cfg = default_logger_cfgs["testtube"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
        # specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:06}",
                "verbose": True,
                "save_last": True,
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["params"]["monitor"] = model.monitor
            default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
        if version.parse(pl.__version__) < version.parse('1.4.0'):
            trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                }
            },
            "image_logger": {
                "target": "main.ImageLogger",
                "params": {
                    "batch_frequency": 750,
                    "max_images": 4,
                    "clamp": True
                }
            },
            "learning_rate_logger": {
                "target": "main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                    # "log_momentum": True
                }
            },
            "cuda_callback": {
                "target": "main.CUDACallback"
            },
        }
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
            print(
                'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
            default_metrics_over_trainsteps_ckpt_dict = {
                'metrics_over_trainsteps_checkpoint':
                    {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                     'params': {
                         "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                         "filename": "{epoch:06}-{step:09}",
                         "verbose": True,
                         'save_top_k': -1,
                         'every_n_train_steps': 10000,
                         'save_weights_only': True
                     }
                     }
            }
            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
            callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
        elif 'ignore_keys_callback' in callbacks_cfg:
            del callbacks_cfg['ignore_keys_callback']

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir  ###

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()
        print("#### Data #####")
        for k in data.datasets:
            print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

        # configure learning rate
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
        else:
            ngpu = 1
        if 'accumulate_grad_batches' in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")


        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)


        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
import
pudb
;
pudb
.
set_trace
()
import
signal
signal
.
signal
(
signal
.
SIGUSR1
,
melk
)
signal
.
signal
(
signal
.
SIGUSR2
,
divein
)
# run
if
opt
.
train
:
try
:
trainer
.
fit
(
model
,
data
)
except
Exception
:
melk
()
raise
if
not
opt
.
no_test
and
not
trainer
.
interrupted
:
trainer
.
test
(
model
,
data
)
except
Exception
:
if
opt
.
debug
and
trainer
.
global_rank
==
0
:
try
:
import
pudb
as
debugger
except
ImportError
:
import
pdb
as
debugger
debugger
.
post_mortem
()
raise
finally
:
# move newly created debug project to debug_runs
if
opt
.
debug
and
not
opt
.
resume
and
trainer
.
global_rank
==
0
:
dst
,
name
=
os
.
path
.
split
(
logdir
)
dst
=
os
.
path
.
join
(
dst
,
"debug_runs"
,
name
)
os
.
makedirs
(
os
.
path
.
split
(
dst
)[
0
],
exist_ok
=
True
)
os
.
rename
(
logdir
,
dst
)
if
trainer
.
global_rank
==
0
:
print
(
trainer
.
profiler
.
summary
())
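For reference, the SIGUSR1/SIGUSR2 handlers registered above can be exercised from another process while training runs; this is only a sketch, and the PID below is purely illustrative.

# minimal sketch (not part of the repo): poke a running main.py process.
import os
import signal

pid = 12345                      # hypothetical PID of the running main.py process
os.kill(pid, signal.SIGUSR1)     # rank 0 prints "Summoning checkpoint." and writes checkpoints/last.ckpt
# os.kill(pid, signal.SIGUSR2)   # would instead drop rank 0 into the pudb debugger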
models/first_stage_models/kl-f16/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: val/rec_loss
    embed_dim: 16
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
    ddconfig:
      double_z: true
      z_channels: 16
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      dropout: 0.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 6
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
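This config is consumed the same way main.py does it above; a minimal sketch (not part of the repo) of building the f16 KL autoencoder directly from it, with checkpoint loading omitted:

# assumes the repo root is on PYTHONPATH; instantiate_from_config comes from ldm/util.py
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cfg = OmegaConf.load("models/first_stage_models/kl-f16/config.yaml")
autoencoder = instantiate_from_config(cfg.model)   # ldm.models.autoencoder.AutoencoderKL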
models/first_stage_models/kl-f32/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: val/rec_loss
    embed_dim: 64
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
    ddconfig:
      double_z: true
      z_channels: 64
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      - 8
      dropout: 0.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 6
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
models/first_stage_models/kl-f4/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: val/rec_loss
    embed_dim: 3
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
    ddconfig:
      double_z: true
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 10
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
models/first_stage_models/kl-f8/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: val/rec_loss
    embed_dim: 4
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 1.0e-06
        disc_weight: 0.5
    ddconfig:
      double_z: true
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
models/first_stage_models/vq-f16/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.VQModel
  params:
    embed_dim: 8
    n_embed: 16384
    ddconfig:
      double_z: false
      z_channels: 8
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 16
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 250001
        disc_weight: 0.75
        disc_num_layers: 2
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 14
    num_workers: 20
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
models/first_stage_models/vq-f4-noattn/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.VQModel
  params:
    embed_dim: 3
    n_embed: 8192
    monitor: val/rec_loss
    ddconfig:
      attn_type: none
      double_z: false
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 11
        disc_weight: 0.75
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 12
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        crop_size: 256
models/first_stage_models/vq-f4/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.VQModel
  params:
    embed_dim: 3
    n_embed: 8192
    monitor: val/rec_loss
    ddconfig:
      double_z: false
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 0
        disc_weight: 0.75
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 8
    num_workers: 16
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        crop_size: 256
models/first_stage_models/vq-f8-n256/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.VQModel
  params:
    embed_dim: 4
    n_embed: 256
    monitor: val/rec_loss
    ddconfig:
      double_z: false
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 32
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_start: 250001
        disc_weight: 0.75
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 10
    num_workers: 20
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
models/first_stage_models/vq-f8/config.yaml

model:
  base_learning_rate: 4.5e-06
  target: ldm.models.autoencoder.VQModel
  params:
    embed_dim: 4
    n_embed: 16384
    monitor: val/rec_loss
    ddconfig:
      double_z: false
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 2
      - 4
      num_res_blocks: 2
      attn_resolutions:
      - 32
      dropout: 0.0
    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: false
        disc_in_channels: 3
        disc_num_layers: 2
        disc_start: 1
        disc_weight: 0.6
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 10
    num_workers: 20
    wrap: true
    train:
      target: ldm.data.openimages.FullOpenImagesTrain
      params:
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.openimages.FullOpenImagesValidation
      params:
        size: 384
        crop_size: 256
models/ldm/bsr_sr/config.yaml

model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0155
    log_every_t: 100
    timesteps: 1000
    loss_type: l2
    first_stage_key: image
    cond_stage_key: LR_image
    image_size: 64
    channels: 3
    concat_mode: true
    cond_stage_trainable: false
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 6
        out_channels: 3
        model_channels: 160
        attention_resolutions:
        - 16
        - 8
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 2
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        monitor: val/rec_loss
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: torch.nn.Identity
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    wrap: false
    num_workers: 12
    train:
      target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain
      params:
        size: 256
        degradation: bsrgan_light
        downscale_f: 4
        min_crop_f: 0.5
        max_crop_f: 1.0
        random_crop: true
    validation:
      target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation
      params:
        size: 256
        degradation: bsrgan_light
        downscale_f: 4
        min_crop_f: 0.5
        max_crop_f: 1.0
        random_crop: true
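The data: block above is consumed by main.DataModuleFromConfig exactly as in the training loop earlier in main.py; a minimal standalone sketch (not part of the repo), assuming the OpenImages data is available locally:

# build only the datamodule described by the data: section of this config
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

cfg = OmegaConf.load("models/ldm/bsr_sr/config.yaml")
data = instantiate_from_config(cfg.data)   # main.DataModuleFromConfig
data.prepare_data()
data.setup()
for name in data.datasets:
    print(name, data.datasets[name].__class__.__name__, len(data.datasets[name]))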
models/ldm/celeba256/config.yaml

model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 64
    channels: 3
    cond_stage_trainable: false
    concat_mode: false
    monitor: val/loss
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    num_workers: 5
    wrap: false
    train:
      target: ldm.data.faceshq.CelebAHQTrain
      params:
        size: 256
    validation:
      target: ldm.data.faceshq.CelebAHQValidation
      params:
        size: 256
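When main.py runs with scale_lr enabled, the effective learning rate for this config follows the accumulate_grad_batches * ngpu * bs * base_lr rule shown earlier; a small worked example, with the GPU count and accumulation factor assumed:

# worked example of the LR scaling in main.py; only bs and base_lr come from this config
accumulate_grad_batches = 1      # assumed
ngpu = 8                         # assumed
bs, base_lr = 48, 2.0e-06        # data.params.batch_size and model.base_learning_rate above
learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(f"{learning_rate:.2e}")    # 7.68e-04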
models/ldm/cin256/config.yaml

model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 256
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 4
        n_embed: 16384
        ddconfig:
          double_z: false
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 32
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        embed_dim: 512
        key: class_label
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    num_workers: 12
    wrap: false
    train:
      target: ldm.data.imagenet.ImageNetTrain
      params:
        config:
          size: 256
    validation:
      target: ldm.data.imagenet.ImageNetValidation
      params:
        config:
          size: 256
models/ldm/ffhq256/config.yaml

model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 64
    channels: 3
    cond_stage_trainable: false
    concat_mode: false
    monitor: val/loss
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 224
        attention_resolutions:
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_unconditional__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 42
    num_workers: 5
    wrap: false
    train:
      target: ldm.data.faceshq.FFHQTrain
      params:
        size: 256
    validation:
      target: ldm.data.faceshq.FFHQValidation
      params:
        size: 256
models/ldm/inpainting_big/config.yaml

model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0205
    log_every_t: 100
    timesteps: 1000
    loss_type: l1
    first_stage_key: image
    cond_stage_key: masked_image
    image_size: 64
    channels: 3
    concat_mode: true
    monitor: val/loss
    scheduler_config:
      target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0
        warm_up_steps: 1000
        max_decay_steps: 50000
        lr_start: 0.001
        lr_max: 0.1
        lr_min: 0.0001
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 7
        out_channels: 3
        model_channels: 256
        attention_resolutions:
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_heads: 8
        resblock_updown: true
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        monitor: val/rec_loss
        ddconfig:
          attn_type: none
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: ldm.modules.losses.contperceptual.DummyLoss
    cond_stage_config: __is_first_stage__
models/ldm/layout2img-openimages256/config.yaml

model:
  base_learning_rate: 2.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0205
    log_every_t: 100
    timesteps: 1000
    loss_type: l1
    first_stage_key: image
    cond_stage_key: coordinates_bbox
    image_size: 64
    channels: 3
    conditioning_key: crossattn
    cond_stage_trainable: true
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 128
        attention_resolutions:
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 4
        num_head_channels: 32
        use_spatial_transformer: true
        transformer_depth: 3
        context_dim: 512
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        monitor: val/rec_loss
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 512
        n_layer: 16
        vocab_size: 8192
        max_seq_len: 92
        use_tokenizer: false
    monitor: val/loss_simple_ema
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 24
    wrap: false
    num_workers: 10
    train:
      target: ldm.data.openimages.OpenImagesBBoxTrain
      params:
        size: 256
    validation:
      target: ldm.data.openimages.OpenImagesBBoxValidation
      params:
        size: 256