Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
df2e145e
Commit
df2e145e
authored
Jun 29, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into unify_resnet
parents
046dc430
8cba133f
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
990 additions
and
1492 deletions
+990
-1492
examples/experimental/train_latent_text_to_image.py
examples/experimental/train_latent_text_to_image.py
+0
-0
examples/train_unconditional.py
examples/train_unconditional.py
+83
-58
setup.py
setup.py
+4
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/dependency_versions_table.py
src/diffusers/dependency_versions_table.py
+1
-0
src/diffusers/hub_utils.py
src/diffusers/hub_utils.py
+35
-19
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-0
src/diffusers/models/vae.py
src/diffusers/models/vae.py
+636
-0
src/diffusers/pipelines/latent_diffusion/__init__.py
src/diffusers/pipelines/latent_diffusion/__init__.py
+1
-1
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
...s/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+0
-849
src/diffusers/pipelines/latent_diffusion_uncond/__init__.py
src/diffusers/pipelines/latent_diffusion_uncond/__init__.py
+1
-1
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+0
-561
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+12
-0
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+1
-2
src/diffusers/utils/model_card_template.md
src/diffusers/utils/model_card_template.md
+50
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+164
-0
No files found.
examples/train_latent_text_to_image.py
→
examples/
experimental/
train_latent_text_to_image.py
View file @
df2e145e
File moved
examples/train_unconditional.py
View file @
df2e145e
...
@@ -4,14 +4,13 @@ import os
...
@@ -4,14 +4,13 @@ import os
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
PIL.Image
from
accelerate
import
Accelerator
from
accelerate
import
Accelerator
from
accelerate.logging
import
get_logger
from
datasets
import
load_dataset
from
datasets
import
load_dataset
from
diffusers
import
DD
P
MPipeline
,
DD
P
MScheduler
,
UNetModel
from
diffusers
import
DD
I
MPipeline
,
DD
I
MScheduler
,
UNetModel
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
logging
from
torchvision.transforms
import
(
from
torchvision.transforms
import
(
CenterCrop
,
CenterCrop
,
Compose
,
Compose
,
...
@@ -24,11 +23,12 @@ from torchvision.transforms import (
...
@@ -24,11 +23,12 @@ from torchvision.transforms import (
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
logger
=
logging
.
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
def
main
(
args
):
def
main
(
args
):
accelerator
=
Accelerator
(
mixed_precision
=
args
.
mixed_precision
)
logging_dir
=
os
.
path
.
join
(
args
.
output_dir
,
args
.
logging_dir
)
accelerator
=
Accelerator
(
mixed_precision
=
args
.
mixed_precision
,
log_with
=
"tensorboard"
,
logging_dir
=
logging_dir
)
model
=
UNetModel
(
model
=
UNetModel
(
attn_resolutions
=
(
16
,),
attn_resolutions
=
(
16
,),
...
@@ -39,8 +39,14 @@ def main(args):
...
@@ -39,8 +39,14 @@ def main(args):
resamp_with_conv
=
True
,
resamp_with_conv
=
True
,
resolution
=
args
.
resolution
,
resolution
=
args
.
resolution
,
)
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
1000
,
tensor_format
=
"pt"
)
noise_scheduler
=
DDIMScheduler
(
timesteps
=
1000
,
tensor_format
=
"pt"
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
optimizer
=
torch
.
optim
.
AdamW
(
model
.
parameters
(),
lr
=
args
.
learning_rate
,
betas
=
(
args
.
adam_beta1
,
args
.
adam_beta2
),
weight_decay
=
args
.
adam_weight_decay
,
eps
=
args
.
adam_epsilon
,
)
augmentations
=
Compose
(
augmentations
=
Compose
(
[
[
...
@@ -58,12 +64,12 @@ def main(args):
...
@@ -58,12 +64,12 @@ def main(args):
return
{
"input"
:
images
}
return
{
"input"
:
images
}
dataset
.
set_transform
(
transforms
)
dataset
.
set_transform
(
transforms
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
)
train_dataloader
=
torch
.
utils
.
data
.
DataLoader
(
dataset
,
batch_size
=
args
.
train_
batch_size
,
shuffle
=
True
)
lr_scheduler
=
get_scheduler
(
lr_scheduler
=
get_scheduler
(
"linear"
,
args
.
lr_scheduler
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
num_warmup_steps
=
args
.
warmup_steps
,
num_warmup_steps
=
args
.
lr_
warmup_steps
,
num_training_steps
=
(
len
(
train_dataloader
)
*
args
.
num_epochs
)
//
args
.
gradient_accumulation_steps
,
num_training_steps
=
(
len
(
train_dataloader
)
*
args
.
num_epochs
)
//
args
.
gradient_accumulation_steps
,
)
)
...
@@ -76,15 +82,19 @@ def main(args):
...
@@ -76,15 +82,19 @@ def main(args):
if
args
.
push_to_hub
:
if
args
.
push_to_hub
:
repo
=
init_git_repo
(
args
,
at_init
=
True
)
repo
=
init_git_repo
(
args
,
at_init
=
True
)
if
accelerator
.
is_main_process
:
run
=
os
.
path
.
split
(
__file__
)[
-
1
].
split
(
"."
)[
0
]
accelerator
.
init_trackers
(
run
)
# Train!
# Train!
is_distributed
=
torch
.
distributed
.
is_available
()
and
torch
.
distributed
.
is_initialized
()
is_distributed
=
torch
.
distributed
.
is_available
()
and
torch
.
distributed
.
is_initialized
()
world_size
=
torch
.
distributed
.
get_world_size
()
if
is_distributed
else
1
world_size
=
torch
.
distributed
.
get_world_size
()
if
is_distributed
else
1
total_train_batch_size
=
args
.
batch_size
*
args
.
gradient_accumulation_steps
*
world_size
total_train_batch_size
=
args
.
train_
batch_size
*
args
.
gradient_accumulation_steps
*
world_size
max_steps
=
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
*
args
.
num_epochs
max_steps
=
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
*
args
.
num_epochs
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
f
" Num examples =
{
len
(
train_dataloader
.
dataset
)
}
"
)
logger
.
info
(
f
" Num examples =
{
len
(
train_dataloader
.
dataset
)
}
"
)
logger
.
info
(
f
" Num Epochs =
{
args
.
num_epochs
}
"
)
logger
.
info
(
f
" Num Epochs =
{
args
.
num_epochs
}
"
)
logger
.
info
(
f
" Instantaneous batch size per device =
{
args
.
batch_size
}
"
)
logger
.
info
(
f
" Instantaneous batch size per device =
{
args
.
train_
batch_size
}
"
)
logger
.
info
(
f
" Total train batch size (w. parallel, distributed & accumulation) =
{
total_train_batch_size
}
"
)
logger
.
info
(
f
" Total train batch size (w. parallel, distributed & accumulation) =
{
total_train_batch_size
}
"
)
logger
.
info
(
f
" Gradient Accumulation steps =
{
args
.
gradient_accumulation_steps
}
"
)
logger
.
info
(
f
" Gradient Accumulation steps =
{
args
.
gradient_accumulation_steps
}
"
)
logger
.
info
(
f
" Total optimization steps =
{
max_steps
}
"
)
logger
.
info
(
f
" Total optimization steps =
{
max_steps
}
"
)
...
@@ -92,65 +102,71 @@ def main(args):
...
@@ -92,65 +102,71 @@ def main(args):
global_step
=
0
global_step
=
0
for
epoch
in
range
(
args
.
num_epochs
):
for
epoch
in
range
(
args
.
num_epochs
):
model
.
train
()
model
.
train
()
with
tqdm
(
total
=
len
(
train_dataloader
),
unit
=
"ba"
)
as
pbar
:
progress_bar
=
tqdm
(
total
=
len
(
train_dataloader
),
disable
=
not
accelerator
.
is_local_main_process
)
pbar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
progress_bar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
for
step
,
batch
in
enumerate
(
train_dataloader
):
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_images
=
batch
[
"input"
]
clean_images
=
batch
[
"input"
]
noise_samples
=
torch
.
randn
(
clean_images
.
shape
).
to
(
clean_images
.
device
)
noise_samples
=
torch
.
randn
(
clean_images
.
shape
).
to
(
clean_images
.
device
)
bsz
=
clean_images
.
shape
[
0
]
bsz
=
clean_images
.
shape
[
0
]
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
# add noise onto the clean images according to the noise magnitude at each timestep
# add noise onto the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
# (this is the forward diffusion process)
noisy_images
=
noise_scheduler
.
training_step
(
clean_images
,
noise_samples
,
timesteps
)
noisy_images
=
noise_scheduler
.
add_noise
(
clean_images
,
noise_samples
,
timesteps
)
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
with
accelerator
.
no_sync
(
model
):
with
accelerator
.
no_sync
(
model
):
output
=
model
(
noisy_images
,
timesteps
)
# predict the noise residual
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
loss
=
loss
/
args
.
gradient_accumulation_steps
accelerator
.
backward
(
loss
)
else
:
output
=
model
(
noisy_images
,
timesteps
)
output
=
model
(
noisy_images
,
timesteps
)
# predict the noise residual
# predict the noise residual
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
loss
=
loss
/
args
.
gradient_accumulation_steps
loss
=
loss
/
args
.
gradient_accumulation_steps
accelerator
.
backward
(
loss
)
accelerator
.
backward
(
loss
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
else
:
optimizer
.
step
()
output
=
model
(
noisy_images
,
timesteps
)
lr_scheduler
.
step
()
# predict the noise residual
ema_model
.
step
(
model
,
global_step
)
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
optimizer
.
zero_grad
()
loss
=
loss
/
args
.
gradient_accumulation_steps
pbar
.
update
(
1
)
accelerator
.
backward
(
loss
)
pbar
.
set_postfix
(
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
loss
=
loss
.
detach
().
item
(),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
],
ema_decay
=
ema_model
.
decay
optimizer
.
step
()
)
lr_scheduler
.
step
()
global_step
+=
1
ema_model
.
step
(
model
,
global_step
)
optimizer
.
zero_grad
()
progress_bar
.
update
(
1
)
progress_bar
.
set_postfix
(
loss
=
loss
.
detach
().
item
(),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
],
ema_decay
=
ema_model
.
decay
)
accelerator
.
log
(
{
"train_loss"
:
loss
.
detach
().
item
(),
"epoch"
:
epoch
,
"ema_decay"
:
ema_model
.
decay
,
"step"
:
global_step
,
},
step
=
global_step
,
)
global_step
+=
1
progress_bar
.
close
()
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
# Generate a sample image for visual inspection
# Generate a sample image for visual inspection
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pipeline
=
DDPMPipeline
(
pipeline
=
DDIMPipeline
(
unet
=
accelerator
.
unwrap_model
(
ema_model
.
averaged_model
),
noise_scheduler
=
noise_scheduler
unet
=
accelerator
.
unwrap_model
(
ema_model
.
averaged_model
),
noise_scheduler
=
noise_scheduler
,
)
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
# run pipeline in inference (sample random noise and denoise)
# run pipeline in inference (sample random noise and denoise)
image
=
pipeline
(
generator
=
generator
)
image
s
=
pipeline
(
generator
=
generator
,
batch_size
=
args
.
eval_batch_size
,
num_inference_steps
=
50
)
# process image to PIL
# denormalize the images and save to tensorboard
image_processed
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)
images_processed
=
(
images
.
cpu
()
+
1.0
)
*
127.5
image_processed
=
(
image_processed
+
1.0
)
*
127.5
images_processed
=
images_processed
.
clamp
(
0
,
255
).
type
(
torch
.
uint8
).
numpy
()
image_processed
=
image_processed
.
type
(
torch
.
uint8
).
numpy
()
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
# save image
accelerator
.
trackers
[
0
].
writer
.
add_images
(
"test_samples"
,
images_processed
,
epoch
)
test_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"test_samples"
)
os
.
makedirs
(
test_dir
,
exist_ok
=
True
)
image_pil
.
save
(
f
"
{
test_dir
}
/
{
epoch
:
04
d
}
.png"
)
# save the model
# save the model
if
args
.
push_to_hub
:
if
args
.
push_to_hub
:
...
@@ -159,6 +175,8 @@ def main(args):
...
@@ -159,6 +175,8 @@ def main(args):
pipeline
.
save_pretrained
(
args
.
output_dir
)
pipeline
.
save_pretrained
(
args
.
output_dir
)
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
accelerator
.
end_training
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Simple example of a training script."
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Simple example of a training script."
)
...
@@ -167,18 +185,25 @@ if __name__ == "__main__":
...
@@ -167,18 +185,25 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"ddpm-model"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"ddpm-model"
)
parser
.
add_argument
(
"--overwrite_output_dir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--overwrite_output_dir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--resolution"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--resolution"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--train_batch_size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--eval_batch_size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--warmup_steps"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--lr_scheduler"
,
type
=
str
,
default
=
"cosine"
)
parser
.
add_argument
(
"--lr_warmup_steps"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--adam_beta1"
,
type
=
float
,
default
=
0.95
)
parser
.
add_argument
(
"--adam_beta2"
,
type
=
float
,
default
=
0.999
)
parser
.
add_argument
(
"--adam_weight_decay"
,
type
=
float
,
default
=
1e-6
)
parser
.
add_argument
(
"--adam_epsilon"
,
type
=
float
,
default
=
1e-3
)
parser
.
add_argument
(
"--ema_inv_gamma"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--ema_inv_gamma"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--ema_power"
,
type
=
float
,
default
=
3
/
4
)
parser
.
add_argument
(
"--ema_power"
,
type
=
float
,
default
=
3
/
4
)
parser
.
add_argument
(
"--ema_max_decay"
,
type
=
float
,
default
=
0.999
)
parser
.
add_argument
(
"--ema_max_decay"
,
type
=
float
,
default
=
0.999
9
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--push_to_hub"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--hub_token"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_token"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_model_id"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_model_id"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--hub_private_repo"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--hub_private_repo"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--logging_dir"
,
type
=
str
,
default
=
"logs"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--mixed_precision"
,
"--mixed_precision"
,
type
=
str
,
type
=
str
,
...
...
setup.py
View file @
df2e145e
...
@@ -87,6 +87,8 @@ _deps = [
...
@@ -87,6 +87,8 @@ _deps = [
"regex!=2019.12.17"
,
"regex!=2019.12.17"
,
"requests"
,
"requests"
,
"torch>=1.4"
,
"torch>=1.4"
,
"tensorboard"
,
"modelcards=0.1.4"
]
]
# this is a lookup table with items like:
# this is a lookup table with items like:
...
@@ -172,6 +174,8 @@ install_requires = [
...
@@ -172,6 +174,8 @@ install_requires = [
deps
[
"requests"
],
deps
[
"requests"
],
deps
[
"torch"
],
deps
[
"torch"
],
deps
[
"Pillow"
],
deps
[
"Pillow"
],
deps
[
"tensorboard"
],
deps
[
"modelcards"
],
]
]
setup
(
setup
(
...
...
src/diffusers/__init__.py
View file @
df2e145e
...
@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
...
@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__
=
"0.0.4"
__version__
=
"0.0.4"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.models
import
AutoencoderKL
,
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
,
VQModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
(
from
.pipelines
import
(
BDDMPipeline
,
BDDMPipeline
,
...
...
src/diffusers/dependency_versions_table.py
View file @
df2e145e
...
@@ -13,4 +13,5 @@ deps = {
...
@@ -13,4 +13,5 @@ deps = {
"regex"
:
"regex!=2019.12.17"
,
"regex"
:
"regex!=2019.12.17"
,
"requests"
:
"requests"
,
"requests"
:
"requests"
,
"torch"
:
"torch>=1.4"
,
"torch"
:
"torch>=1.4"
,
"tensorboard"
:
"tensorboard"
,
}
}
src/diffusers/hub_utils.py
View file @
df2e145e
...
@@ -19,9 +19,9 @@ import shutil
...
@@ -19,9 +19,9 @@ import shutil
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Optional
from
typing
import
Optional
import
yaml
from
diffusers
import
DiffusionPipeline
from
diffusers
import
DiffusionPipeline
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
modelcards
import
CardData
,
ModelCard
from
.utils
import
logging
from
.utils
import
logging
...
@@ -29,10 +29,7 @@ from .utils import logging
...
@@ -29,10 +29,7 @@ from .utils import logging
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
AUTOGENERATED_TRAINER_COMMENT
=
"""
MODEL_CARD_TEMPLATE_PATH
=
Path
(
__file__
).
parent
/
"utils"
/
"model_card_template.md"
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""
def
get_full_repo_name
(
model_id
:
str
,
organization
:
Optional
[
str
]
=
None
,
token
:
Optional
[
str
]
=
None
):
def
get_full_repo_name
(
model_id
:
str
,
organization
:
Optional
[
str
]
=
None
,
token
:
Optional
[
str
]
=
None
):
...
@@ -152,17 +149,36 @@ def create_model_card(args, model_name):
...
@@ -152,17 +149,36 @@ def create_model_card(args, model_name):
if
args
.
local_rank
not
in
[
-
1
,
0
]:
if
args
.
local_rank
not
in
[
-
1
,
0
]:
return
return
# TODO: replace this placeholder model card generation
repo_name
=
get_full_repo_name
(
model_name
,
token
=
args
.
hub_token
)
model_card
=
""
model_card
=
ModelCard
.
from_template
(
metadata
=
{
"license"
:
"apache-2.0"
,
"tags"
:
[
"pytorch"
,
"diffusers"
]}
card_data
=
CardData
(
# Card metadata object that will be converted to YAML block
metadata
=
yaml
.
dump
(
metadata
,
sort_keys
=
False
)
language
=
"en"
,
if
len
(
metadata
)
>
0
:
license
=
"apache-2.0"
,
model_card
=
f
"---
\n
{
metadata
}
---
\n
"
library_name
=
"diffusers"
,
tags
=
[],
model_card
+=
AUTOGENERATED_TRAINER_COMMENT
datasets
=
args
.
dataset
,
metrics
=
[],
model_card
+=
f
"
\n
#
{
model_name
}
\n\n
"
),
template_path
=
MODEL_CARD_TEMPLATE_PATH
,
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
"README.md"
),
"w"
)
as
f
:
model_name
=
model_name
,
f
.
write
(
model_card
)
repo_name
=
repo_name
,
dataset_name
=
args
.
dataset
,
learning_rate
=
args
.
learning_rate
,
train_batch_size
=
args
.
train_batch_size
,
eval_batch_size
=
args
.
eval_batch_size
,
gradient_accumulation_steps
=
args
.
gradient_accumulation_steps
,
adam_beta1
=
args
.
adam_beta1
,
adam_beta2
=
args
.
adam_beta2
,
adam_weight_decay
=
args
.
adam_weight_decay
,
adam_epsilon
=
args
.
adam_epsilon
,
lr_scheduler
=
args
.
lr_scheduler
,
lr_warmup_steps
=
args
.
lr_warmup_steps
,
ema_inv_gamma
=
args
.
ema_inv_gamma
,
ema_power
=
args
.
ema_power
,
ema_max_decay
=
args
.
ema_max_decay
,
mixed_precision
=
args
.
mixed_precision
,
)
card_path
=
os
.
path
.
join
(
args
.
output_dir
,
"README.md"
)
model_card
.
save
(
card_path
)
src/diffusers/models/__init__.py
View file @
df2e145e
...
@@ -22,3 +22,4 @@ from .unet_grad_tts import UNetGradTTSModel
...
@@ -22,3 +22,4 @@ from .unet_grad_tts import UNetGradTTSModel
from
.unet_ldm
import
UNetLDMModel
from
.unet_ldm
import
UNetLDMModel
from
.unet_rl
import
TemporalUNet
from
.unet_rl
import
TemporalUNet
from
.unet_sde_score_estimation
import
NCSNpp
from
.unet_sde_score_estimation
import
NCSNpp
from
.vae
import
AutoencoderKL
,
VQModel
src/diffusers/models/vae.py
0 → 100644
View file @
df2e145e
import
math
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.resnet
import
Downsample
,
Upsample
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
use_conv
=
resamp_with_conv
,
padding
=
0
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
use_conv
=
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
):
# reshape z -> (batch, height, width, channel) and flatten
z
=
z
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
self
.
embedding
.
weight
.
t
())
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
DiagonalGaussianDistribution
(
object
):
def
__init__
(
self
,
parameters
,
deterministic
=
False
):
self
.
parameters
=
parameters
self
.
mean
,
self
.
logvar
=
torch
.
chunk
(
parameters
,
2
,
dim
=
1
)
self
.
logvar
=
torch
.
clamp
(
self
.
logvar
,
-
30.0
,
20.0
)
self
.
deterministic
=
deterministic
self
.
std
=
torch
.
exp
(
0.5
*
self
.
logvar
)
self
.
var
=
torch
.
exp
(
self
.
logvar
)
if
self
.
deterministic
:
self
.
var
=
self
.
std
=
torch
.
zeros_like
(
self
.
mean
).
to
(
device
=
self
.
parameters
.
device
)
def
sample
(
self
):
x
=
self
.
mean
+
self
.
std
*
torch
.
randn
(
self
.
mean
.
shape
).
to
(
device
=
self
.
parameters
.
device
)
return
x
def
kl
(
self
,
other
=
None
):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.0
])
else
:
if
other
is
None
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
else
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
-
other
.
mean
,
2
)
/
other
.
var
+
self
.
var
/
other
.
var
-
1.0
-
self
.
logvar
+
other
.
logvar
,
dim
=
[
1
,
2
,
3
],
)
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.0
])
logtwopi
=
np
.
log
(
2.0
*
np
.
pi
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
def
mode
(
self
):
return
self
.
mean
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
z_channels
,
embed_dim
,
1
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
def
forward
(
self
,
x
):
h
=
self
.
encode
(
x
)
dec
=
self
.
decode
(
h
)
return
dec
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
moments
=
self
.
quant_conv
(
h
)
posterior
=
DiagonalGaussianDistribution
(
moments
)
return
posterior
def
decode
(
self
,
z
):
z
=
self
.
post_quant_conv
(
z
)
dec
=
self
.
decoder
(
z
)
return
dec
def
forward
(
self
,
x
,
sample_posterior
=
False
):
posterior
=
self
.
encode
(
x
)
if
sample_posterior
:
z
=
posterior
.
sample
()
else
:
z
=
posterior
.
mode
()
dec
=
self
.
decode
(
z
)
return
dec
src/diffusers/pipelines/latent_diffusion/__init__.py
View file @
df2e145e
...
@@ -2,4 +2,4 @@ from ...utils import is_transformers_available
...
@@ -2,4 +2,4 @@ from ...utils import is_transformers_available
if
is_transformers_available
():
if
is_transformers_available
():
from
.pipeline_latent_diffusion
import
AutoencoderKL
,
LatentDiffusionPipeline
,
LDMBertModel
from
.pipeline_latent_diffusion
import
LatentDiffusionPipeline
,
LDMBertModel
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
View file @
df2e145e
import
math
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -13,8 +12,6 @@ from transformers.modeling_outputs import BaseModelOutput
...
@@ -13,8 +12,6 @@ from transformers.modeling_outputs import BaseModelOutput
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
logging
from
transformers.utils
import
logging
from
...configuration_utils
import
ConfigMixin
from
...modeling_utils
import
ModelMixin
from
...pipeline_utils
import
DiffusionPipeline
from
...pipeline_utils
import
DiffusionPipeline
...
@@ -547,852 +544,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
...
@@ -547,852 +544,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
return
sequence_output
return
sequence_output
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
class
Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
use_timestep
=
True
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
self
.
ch
*
4
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
use_timestep
=
use_timestep
if
self
.
use_timestep
:
# timestep embedding
self
.
temb
=
nn
.
Module
()
self
.
temb
.
dense
=
nn
.
ModuleList
(
[
torch
.
nn
.
Linear
(
self
.
ch
,
self
.
temb_ch
),
torch
.
nn
.
Linear
(
self
.
temb_ch
,
self
.
temb_ch
),
]
)
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
skip_in
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
if
i_block
==
self
.
num_res_blocks
:
skip_in
=
ch
*
in_ch_mult
[
i_level
]
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
+
skip_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
,
t
=
None
):
# assert x.shape[2] == x.shape[3] == self.resolution
if
self
.
use_timestep
:
# timestep embedding
assert
t
is
not
None
temb
=
get_timestep_embedding
(
t
,
self
.
ch
)
temb
=
self
.
temb
.
dense
[
0
](
temb
)
temb
=
nonlinearity
(
temb
)
temb
=
self
.
temb
.
dense
[
1
](
temb
)
else
:
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
),
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
,
temp
=
None
,
rescale_logits
=
False
,
return_logits
=
False
):
assert
temp
is
None
or
temp
==
1.0
,
"Only for interface compatible with Gumbel"
assert
rescale_logits
==
False
,
"Only for interface compatible with Gumbel"
assert
return_logits
==
False
,
"Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z
=
rearrange
(
z
,
"b c h w -> b h w c"
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
rearrange
(
self
.
embedding
.
weight
,
"n d -> d n"
))
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
rearrange
(
z_q
,
"b h w c -> b c h w"
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
class
DiagonalGaussianDistribution
(
object
):
def
__init__
(
self
,
parameters
,
deterministic
=
False
):
self
.
parameters
=
parameters
self
.
mean
,
self
.
logvar
=
torch
.
chunk
(
parameters
,
2
,
dim
=
1
)
self
.
logvar
=
torch
.
clamp
(
self
.
logvar
,
-
30.0
,
20.0
)
self
.
deterministic
=
deterministic
self
.
std
=
torch
.
exp
(
0.5
*
self
.
logvar
)
self
.
var
=
torch
.
exp
(
self
.
logvar
)
if
self
.
deterministic
:
self
.
var
=
self
.
std
=
torch
.
zeros_like
(
self
.
mean
).
to
(
device
=
self
.
parameters
.
device
)
def
sample
(
self
):
x
=
self
.
mean
+
self
.
std
*
torch
.
randn
(
self
.
mean
.
shape
).
to
(
device
=
self
.
parameters
.
device
)
return
x
def
kl
(
self
,
other
=
None
):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.0
])
else
:
if
other
is
None
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
else
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
-
other
.
mean
,
2
)
/
other
.
var
+
self
.
var
/
other
.
var
-
1.0
-
self
.
logvar
+
other
.
logvar
,
dim
=
[
1
,
2
,
3
],
)
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.0
])
logtwopi
=
np
.
log
(
2.0
*
np
.
pi
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
def
mode
(
self
):
return
self
.
mean
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
moments
=
self
.
quant_conv
(
h
)
posterior
=
DiagonalGaussianDistribution
(
moments
)
return
posterior
def
decode
(
self
,
z
):
z
=
self
.
post_quant_conv
(
z
)
dec
=
self
.
decoder
(
z
)
return
dec
def
forward
(
self
,
input
,
sample_posterior
=
True
):
posterior
=
self
.
encode
(
input
)
if
sample_posterior
:
z
=
posterior
.
sample
()
else
:
z
=
posterior
.
mode
()
dec
=
self
.
decode
(
z
)
return
dec
,
posterior
class
LatentDiffusionPipeline
(
DiffusionPipeline
):
class
LatentDiffusionPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
noise_scheduler
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
noise_scheduler
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/pipelines/latent_diffusion_uncond/__init__.py
View file @
df2e145e
from
.pipeline_latent_diffusion_uncond
import
LatentDiffusionUncondPipeline
from
.pipeline_latent_diffusion_uncond
import
LatentDiffusionUncondPipeline
,
VQModel
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
df2e145e
import
math
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
tqdm
import
tqdm
from
...configuration_utils
import
ConfigMixin
from
...modeling_utils
import
ModelMixin
from
...pipeline_utils
import
DiffusionPipeline
from
...pipeline_utils
import
DiffusionPipeline
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
):
# reshape z -> (batch, height, width, channel) and flatten
z
=
z
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
self
.
embedding
.
weight
.
t
())
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
class
LatentDiffusionUncondPipeline
(
DiffusionPipeline
):
class
LatentDiffusionUncondPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
unet
,
noise_scheduler
):
def
__init__
(
self
,
vqvae
,
unet
,
noise_scheduler
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/schedulers/scheduling_ddim.py
View file @
df2e145e
...
@@ -71,6 +71,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -71,6 +71,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if
beta_schedule
==
"linear"
:
if
beta_schedule
==
"linear"
:
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
np
.
float32
)
self
.
betas
=
np
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
np
.
float32
)
elif
beta_schedule
==
"scaled_linear"
:
# this schedule is very specific to the latent diffusion model.
self
.
betas
=
np
.
linspace
(
beta_start
**
0.5
,
beta_end
**
0.5
,
timesteps
,
dtype
=
np
.
float32
)
**
2
elif
beta_schedule
==
"squaredcos_cap_v2"
:
elif
beta_schedule
==
"squaredcos_cap_v2"
:
# Glide cosine schedule
# Glide cosine schedule
self
.
betas
=
betas_for_alpha_bar
(
timesteps
)
self
.
betas
=
betas_for_alpha_bar
(
timesteps
)
...
@@ -142,5 +145,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -142,5 +145,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
return
pred_prev_sample
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
match_shape
(
sqrt_alpha_prod
,
original_samples
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
self
.
match_shape
(
sqrt_one_minus_alpha_prod
,
original_samples
)
noisy_samples
=
sqrt_alpha_prod
*
original_samples
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_samples
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
config
.
timesteps
return
self
.
config
.
timesteps
src/diffusers/schedulers/scheduling_ddpm.py
View file @
df2e145e
...
@@ -17,7 +17,6 @@
...
@@ -17,7 +17,6 @@
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -142,7 +141,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -142,7 +141,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
pred_prev_sample
return
pred_prev_sample
def
training_step
(
self
,
original_samples
:
torch
.
Tensor
,
noise
:
torch
.
Tensor
,
timesteps
:
torch
.
Tensor
):
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
match_shape
(
sqrt_alpha_prod
,
original_samples
)
sqrt_alpha_prod
=
self
.
match_shape
(
sqrt_alpha_prod
,
original_samples
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
...
...
src/diffusers/utils/model_card_template.md
0 → 100644
View file @
df2e145e
---
{{
card_data
}}
---
<!-- This model card has been generated automatically according to the information the training script had access to. You
should probably proofread and complete it, then remove this comment. -->
# {{ model_name | default("Diffusion Model") }}
## Model description
This diffusion model is trained with the
[
🤗 Diffusers
](
https://github.com/huggingface/diffusers
)
library
on the
`{{ dataset_name }}`
dataset.
## Intended uses & limitations
#### How to use
```
python
# TODO: add an example code snippet for running this diffusion pipeline
```
#### Limitations and bias
[TODO: provide examples of latent issues and potential remediations]
## Training data
[TODO: describe the data used to train the model]
### Training hyperparameters
The following hyperparameters were used during training:
-
learning_rate: {{ learning_rate }}
-
train_batch_size: {{ train_batch_size }}
-
eval_batch_size: {{ eval_batch_size }}
-
gradient_accumulation_steps: {{ gradient_accumulation_steps }}
-
optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
-
lr_scheduler: {{ lr_scheduler }}
-
lr_warmup_steps: {{ lr_warmup_steps }}
-
ema_inv_gamma: {{ ema_inv_gamma }}
-
ema_inv_gamma: {{ ema_power }}
-
ema_inv_gamma: {{ ema_max_decay }}
-
mixed_precision: {{ mixed_precision }}
### Training results
📈
[
TensorBoard logs
](
https://huggingface.co/{{
repo_name }}/tensorboard?#scalars)
tests/test_modeling_utils.py
View file @
df2e145e
...
@@ -22,6 +22,7 @@ import numpy as np
...
@@ -22,6 +22,7 @@ import numpy as np
import
torch
import
torch
from
diffusers
import
(
from
diffusers
import
(
AutoencoderKL
,
BDDMPipeline
,
BDDMPipeline
,
DDIMPipeline
,
DDIMPipeline
,
DDIMScheduler
,
DDIMScheduler
,
...
@@ -33,6 +34,7 @@ from diffusers import (
...
@@ -33,6 +34,7 @@ from diffusers import (
GradTTSPipeline
,
GradTTSPipeline
,
GradTTSScheduler
,
GradTTSScheduler
,
LatentDiffusionPipeline
,
LatentDiffusionPipeline
,
LatentDiffusionUncondPipeline
,
NCSNpp
,
NCSNpp
,
PNDMPipeline
,
PNDMPipeline
,
PNDMScheduler
,
PNDMScheduler
,
...
@@ -44,6 +46,7 @@ from diffusers import (
...
@@ -44,6 +46,7 @@ from diffusers import (
UNetGradTTSModel
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetLDMModel
,
UNetModel
,
UNetModel
,
VQModel
,
)
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
@@ -805,6 +808,154 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -805,6 +808,154 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
VQModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
VQModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
image
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
return
{
"x"
:
image
}
@
property
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
64
,
"out_ch"
:
3
,
"num_res_blocks"
:
1
,
"attn_resolutions"
:
[],
"in_channels"
:
3
,
"resolution"
:
32
,
"z_channels"
:
3
,
"n_embed"
:
256
,
"embed_dim"
:
3
,
"sane_index_shape"
:
False
,
"ch_mult"
:
(
1
,),
"dropout"
:
0.0
,
"double_z"
:
False
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_forward_signature
(
self
):
pass
def
test_training
(
self
):
pass
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
VQModel
.
from_pretrained
(
"fusing/vqgan-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
VQModel
.
from_pretrained
(
"fusing/vqgan-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
with
torch
.
no_grad
():
output
=
model
(
image
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
1.1321
,
0.1056
,
0.3505
,
-
0.6461
,
-
0.2014
,
0.0419
,
-
0.5763
,
-
0.8462
,
-
0.4218
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
AutoEncoderKLTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
AutoencoderKL
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
image
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
return
{
"x"
:
image
}
@
property
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
64
,
"ch_mult"
:
(
1
,),
"embed_dim"
:
4
,
"in_channels"
:
3
,
"num_res_blocks"
:
1
,
"out_ch"
:
3
,
"resolution"
:
32
,
"z_channels"
:
4
,
"attn_resolutions"
:
[],
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_forward_signature
(
self
):
pass
def
test_training
(
self
):
pass
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
with
torch
.
no_grad
():
output
=
model
(
image
,
sample_posterior
=
True
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0814
,
-
0.0229
,
-
0.1320
,
-
0.4123
,
-
0.0366
,
-
0.3473
,
0.0438
,
-
0.1662
,
0.1750
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
PipelineTesterMixin
(
unittest
.
TestCase
):
class
PipelineTesterMixin
(
unittest
.
TestCase
):
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
# 1. Load models
...
@@ -1000,6 +1151,19 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -1000,6 +1151,19 @@ class PipelineTesterMixin(unittest.TestCase):
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
@
slow
def
test_ldm_uncond
(
self
):
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"fusing/latent-diffusion-celeba-256"
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
tensor
([
0.5025
,
0.4121
,
0.3851
,
0.4806
,
0.3996
,
0.3745
,
0.4839
,
0.4559
,
0.4293
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment