renzhc / diffusers_dcu · Commits

Commit 26ea58d4, authored Jun 27, 2022 by patil-suraj

    Merge branch 'main' of https://github.com/huggingface/diffusers into main

Parents: d1fb3093, 4261c3aa
Changes: 30

Showing 20 changed files with 568 additions and 319 deletions (+568, -319)
Makefile                                                +2    -9
examples/experimental/train_glide_text_to_image.py      +201  -0
examples/train_latent_text_to_image.py                  +45   -31
examples/train_unconditional.py                         +6    -3
src/diffusers/hub_utils.py                              +6    -7
src/diffusers/modeling_utils.py                         +14   -14
src/diffusers/models/embeddings.py                      +3    -5
src/diffusers/models/resnet.py                          +84   -6
src/diffusers/models/unet_glide.py                      +26   -40
src/diffusers/models/unet_grad_tts.py                   +9    -9
src/diffusers/models/unet_ldm.py                        +70   -80
src/diffusers/models/unet_rl.py                         +37   -18
src/diffusers/models/unet_sde_score_estimation.py       +38   -59
src/diffusers/pipelines/grad_tts_utils.py               +3    -3
src/diffusers/pipelines/pipeline_bddm.py                +2    -4
src/diffusers/pipelines/pipeline_glide.py               +2    -3
src/diffusers/pipelines/pipeline_grad_tts.py            +1    -1
src/diffusers/pipelines/pipeline_latent_diffusion.py    +5    -7
src/diffusers/schedulers/scheduling_ddim.py             +5    -6
src/diffusers/schedulers/scheduling_ddpm.py             +9    -14
Makefile    View file @ 26ea58d4

@@ -34,13 +34,9 @@ autogenerate_code: deps_table_update

 # Check that the repo is in a good state
 repo-consistency:
-	python utils/check_copies.py
-	python utils/check_table.py
 	python utils/check_dummies.py
 	python utils/check_repo.py
 	python utils/check_inits.py
-	python utils/check_config_docstrings.py
-	python utils/tests_fetcher.py --sanity_check

 # this target runs checks on all files

@@ -48,14 +44,13 @@ quality:
 	black --check --preview $(check_dirs)
 	isort --check-only $(check_dirs)
 	flake8 $(check_dirs)
-	doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
+	doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source

 # Format source code automatically and check is there are any problems left that need manual fixing
 extra_style_checks:
 	python utils/custom_init_isort.py
-	python utils/sort_auto_mappings.py
-	doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
+	doc-builder style src/diffusers docs/source --max_len 119 --path_to_docs docs/source

 # this target runs checks on all files and potentially modifies some of them

@@ -73,8 +68,6 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
 fix-copies:
 	python utils/check_dummies.py --fix_and_overwrite
-	python utils/check_table.py --fix_and_overwrite
-	python utils/check_copies.py --fix_and_overwrite

 # Run tests for the library
examples/experimental/train_glide_text_to_image.py    0 → 100644    View file @ 26ea58d4

import argparse
import os

import torch
import torch.nn.functional as F

import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, Glide, GlideUNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = logging.get_logger(__name__)


def main(args):
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    pipeline = Glide.from_pretrained("fusing/glide-base")
    model = pipeline.text_unet
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    text_encoder = pipeline.text_encoder.eval()

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
        text_inputs = text_inputs.input_ids.to(accelerator.device)
        with torch.no_grad():
            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state
        return {"images": images, "text_embeddings": text_embeddings}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
    logger.info(f"  Num Epochs = {args.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_steps}")

    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_images = batch["images"]
                batch_size, n_channels, height, width = clean_images.shape
                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
                timesteps = torch.randint(
                    0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device
                ).long()

                # add noise onto the clean images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)

                if step % args.gradient_accumulation_steps != 0:
                    with accelerator.no_sync(model):
                        model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                        model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                        # Learn the variance using the variational bound, but don't let
                        # it affect our mean prediction.
                        frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                        # predict the noise residual
                        loss = F.mse_loss(model_output, noise_samples)
                        loss = loss / args.gradient_accumulation_steps
                        accelerator.backward(loss)
                        optimizer.step()
                else:
                    model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                    model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                    # Learn the variance using the variational bound, but don't let
                    # it affect our mean prediction.
                    frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                    # predict the noise residual
                    loss = F.mse_loss(model_output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            model.eval()
            with torch.no_grad():
                pipeline.unet = accelerator.unwrap_model(model)
                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50)

                # process image to PIL
                image_processed = image.squeeze(0)
                image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
                image_pil = PIL.Image.fromarray(image_processed)

                # save image
                test_dir = os.path.join(args.output_dir, "test_samples")
                os.makedirs(test_dir, exist_ok=True)
                image_pil.save(f"{test_dir}/{epoch:04d}.png")

            # save the model
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
    parser.add_argument("--output_dir", type=str, default="glide-text2image")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)
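
Both training scripts in this commit rely on noise_scheduler.training_step(clean_images, noise_samples, timesteps) to produce the noisy inputs. As a hedged sketch of the forward-diffusion step this performs (the helper name and the alphas_cumprod argument are assumptions for illustration, not part of this commit):

import torch

# Sketch of DDPM forward diffusion q(x_t | x_0):
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
def add_noise(clean_images, noise, timesteps, alphas_cumprod):
    # alphas_cumprod: 1-D tensor of cumulative products of (1 - beta_t), one entry per timestep
    sqrt_alpha_bar = alphas_cumprod[timesteps].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_bar = (1.0 - alphas_cumprod[timesteps]).sqrt().view(-1, 1, 1, 1)
    return sqrt_alpha_bar * clean_images + sqrt_one_minus_alpha_bar * noise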
examples/train_latent_text_to_image.py    View file @ 26ea58d4

@@ -4,19 +4,19 @@ import os
 import torch
 import torch.nn.functional as F

+import bitsandbytes as bnb
 import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
-from diffusers import DDPM, DDPMScheduler, UNetLDMModel
+from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
-from diffusers.modeling_utils import unwrap_model
 from diffusers.optimization import get_scheduler
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
     Compose,
     InterpolationMode,
-    Lambda,
+    Normalize,
     RandomHorizontalFlip,
     Resize,
     ToTensor,

@@ -30,6 +30,8 @@ logger = logging.get_logger(__name__)

 def main(args):
     accelerator = Accelerator(mixed_precision=args.mixed_precision)
+    pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large")
+    pipeline.unet = None  # this model will be trained from scratch now
     model = UNetLDMModel(
         attention_resolutions=[4, 2, 1],
         channel_mult=[1, 2, 4, 4],

@@ -37,7 +39,7 @@ def main(args):
         conv_resample=True,
         dims=2,
         dropout=0,
-        image_size=32,
+        image_size=8,
         in_channels=4,
         model_channels=320,
         num_heads=8,

@@ -51,7 +53,7 @@ def main(args):
         legacy=False,
     )
     noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)

     augmentations = Compose(
         [

@@ -59,14 +61,22 @@ def main(args):
             CenterCrop(args.resolution),
             RandomHorizontalFlip(),
             ToTensor(),
-            Lambda(lambda x: x * 2 - 1),
+            Normalize([0.5], [0.5]),
         ]
     )
     dataset = load_dataset(args.dataset, split="train")

+    text_encoder = pipeline.bert.eval()
+    vqvae = pipeline.vqvae.eval()
+
     def transforms(examples):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
-        return {"input": images}
+        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
+        with torch.no_grad():
+            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state
+            images = 1 / 0.18215 * torch.stack(images, dim=0)
+            latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode()
+        return {"images": images, "text_embeddings": text_embeddings, "latents": latents}

     dataset.set_transform(transforms)
     train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

@@ -78,9 +88,11 @@ def main(args):
         num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
     )

-    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        model, optimizer, train_dataloader, lr_scheduler
+    model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler
     )
+    text_encoder = text_encoder.cpu()
+    vqvae = vqvae.cpu()

     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)

@@ -98,29 +110,31 @@ def main(args):
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {max_steps}")

-    global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
         with tqdm(total=len(train_dataloader), unit="ba") as pbar:
             pbar.set_description(f"Epoch {epoch}")
             for step, batch in enumerate(train_dataloader):
-                clean_images = batch["input"]
-                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
-                bsz = clean_images.shape[0]
-                timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
+                clean_latents = batch["latents"]
+                noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device)
+                bsz = clean_latents.shape[0]
+                timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long()

-                # add noise onto the clean images according to the noise magnitude at each timestep
+                # add noise onto the clean latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
+                noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps)

                 if step % args.gradient_accumulation_steps != 0:
                     with accelerator.no_sync(model):
-                        output = model(noisy_images, timesteps)
+                        output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
                         # predict the noise residual
                         loss = F.mse_loss(output, noise_samples)
                         loss = loss / args.gradient_accumulation_steps
                         accelerator.backward(loss)
+                        optimizer.step()
                 else:
-                    output = model(noisy_images, timesteps)
+                    output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
                     # predict the noise residual
                     loss = F.mse_loss(output, noise_samples)
                     loss = loss / args.gradient_accumulation_steps

@@ -131,24 +145,25 @@ def main(args):
                 optimizer.zero_grad()
                 pbar.update(1)
                 pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
-                global_step += 1
+            optimizer.step()

-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()

         # Generate a sample image for visual inspection
-        if args.local_rank in [-1, 0]:
+        if accelerator.is_main_process:
             model.eval()
             with torch.no_grad():
-                pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
+                pipeline.unet = accelerator.unwrap_model(model)
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                image = pipeline(generator=generator)
+                image = pipeline(["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)

                 # process image to PIL
                 image_processed = image.cpu().permute(0, 2, 3, 1)
-                image_processed = (image_processed + 1.0) * 127.5
+                image_processed = image_processed * 255.0
                 image_processed = image_processed.type(torch.uint8).numpy()
                 image_pil = PIL.Image.fromarray(image_processed[0])

@@ -162,20 +177,19 @@ def main(args):
                 push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
             else:
                 pipeline.save_pretrained(args.output_dir)
-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
+    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model")
+    parser.add_argument("--output_dir", type=str, default="ldm-text2image")
     parser.add_argument("--overwrite_output_dir", action="store_true")
-    parser.add_argument("--resolution", type=int, default=64)
+    parser.add_argument("--resolution", type=int, default=128)
-    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
     parser.add_argument("--push_to_hub", action="store_true")
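
The new transform above tokenizes the captions, stacks and rescales the images, and encodes them with the pipeline's VQ-VAE (vqvae.encode(...).mode()) so the UNet is trained on latents rather than pixels. A minimal, hedged sketch of that flow as a standalone helper (the helper name and arguments are illustrative only, and describing 0.18215 as the latent-diffusion scaling constant is an assumption about intent, not something stated in the diff):

import torch

def images_to_latents(pil_batch, augmentations, vqvae):
    # Hypothetical helper mirroring the transform in the diff: augment, stack,
    # rescale by 1 / 0.18215 (the scaling constant used by latent diffusion),
    # then encode to deterministic VQ-VAE latents.
    images = torch.stack([augmentations(img.convert("RGB")) for img in pil_batch], dim=0)
    images = 1 / 0.18215 * images
    with torch.no_grad():
        latents = vqvae.encode(images).mode()
    return latents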
examples/train_unconditional.py    View file @ 26ea58d4

@@ -7,7 +7,7 @@ import torch.nn.functional as F
 import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
-from diffusers import DDPM, DDPMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel

@@ -71,7 +71,7 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
-    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)

@@ -133,7 +133,7 @@ def main(args):
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDPM(
+                pipeline = DDPMPipeline(
                     unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
                 )

@@ -172,6 +172,9 @@ if __name__ == "__main__":
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
+    parser.add_argument("--ema_power", type=float, default=3 / 4)
+    parser.add_argument("--ema_max_decay", type=float, default=0.999)
     parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument("--hub_token", type=str, default=None)
     parser.add_argument("--hub_model_id", type=str, default=None)
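
The three new flags feed EMAModel(model, inv_gamma=..., power=..., max_value=...). A hedged sketch of the warmup-style decay schedule these parameters usually describe (the exact formula inside diffusers' EMAModel is not shown in this diff, so treat this as an assumption):

def ema_decay(step, inv_gamma=1.0, power=3 / 4, max_value=0.999):
    # Decay ramps up from 0 toward max_value as training progresses, so the
    # averaged model ignores very early, noisy weights and then tracks slowly.
    decay = 1.0 - (1.0 + step / inv_gamma) ** -power
    return max(0.0, min(decay, max_value))

# e.g. ema_decay(0) == 0.0, while ema_decay(10_000) is clipped to max_value (0.999)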
src/diffusers/hub_utils.py    View file @ 26ea58d4

@@ -47,12 +47,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 def init_git_repo(args, at_init: bool = False):
     """
-    Initializes a git repo in `args.hub_model_id`.
     Args:
+    Initializes a git repo in `args.hub_model_id`.
         at_init (`bool`, *optional*, defaults to `False`):
-            Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
-            `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
-            out.
+            Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
+            and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
     """
     if args.local_rank not in [-1, 0]:
         return

@@ -102,8 +101,8 @@ def push_to_hub(
     **kwargs,
 ) -> str:
     """
-    Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
     Parameters:
+    Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
         commit_message (`str`, *optional*, defaults to `"End of training"`):
             Message to commit while pushing.
         blocking (`bool`, *optional*, defaults to `True`):

@@ -111,8 +110,8 @@ def push_to_hub(
     kwargs:
         Additional keyword arguments passed along to [`create_model_card`].
     Returns:
         The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
         commit and an object to track the progress of the commit if `blocking=True`
     """
     if args.hub_model_id is None:
src/diffusers/modeling_utils.py    View file @ 26ea58d4
(the hunks in this file only re-wrap the docstring lines shown below)

@@ -123,16 +123,16 @@ class ModelMixin(torch.nn.Module):
     r"""
     Base class for all models.

    [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
    and saving models as well as a few methods common to all models to:

        - resize the input embeddings,
        - prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):

        - **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this
          model architecture.
        - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
          taking as arguments:

@@ -227,8 +227,8 @@ class ModelMixin(torch.nn.Module):
                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                  user or organization name, like `dbmdz/bert-base-german-cased`.
                - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
                  e.g., `./my_model_directory/`.
            config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
                Can be either:

@@ -236,13 +236,13 @@ class ModelMixin(torch.nn.Module):
                - an instance of a class derived from [`ConfigMixin`],
                - a string or path valid as input to [`~ConfigMixin.from_pretrained`].

                ConfigMixinuration for the model to use instead of an automatically loaded configuration.
                ConfigMixinuration can be automatically loaded when:

                - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                  model).
                - The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save
                  directory.
                - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                  configuration JSON file named *config.json* is found in the directory.
            cache_dir (`Union[str, os.PathLike]`, *optional*):

@@ -292,10 +292,10 @@ class ModelMixin(torch.nn.Module):
                  underlying model's `__init__` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                  initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds
                  to a configuration attribute will be used to override said attribute with the supplied `kwargs`
                  value. Remaining keys that do not correspond to any configuration attribute will be passed to the
                  underlying model's `__init__` function.

        <Tip>
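
The from_pretrained docstring above covers two sources of weights: a Hub model id or a local directory produced by save_pretrained. A minimal, hedged usage sketch (the local path and the commented-out model id are placeholders, not values referenced by this commit):

from diffusers import UNetModel

# Load weights from a directory previously written by model.save_pretrained(...)
model = UNetModel.from_pretrained("./my_model_directory/")

# Or load by model id from huggingface.co (placeholder id, shown only as an example):
# model = UNetModel.from_pretrained("some-org/some-diffusion-unet")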
src/diffusers/models/embeddings.py    View file @ 26ea58d4

@@ -22,14 +22,12 @@ def get_timestep_embedding(
     timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
 ):
     """
-    This matches the implementation in Denoising Diffusion Probabilistic Models:
-    Create sinusoidal timestep embeddings.
+    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

     :param timesteps: a 1-D Tensor of N indices, one per batch element.
                       These may be fractional.
-    :param embedding_dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
+    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
+    embeddings. :return: an [N x dim] Tensor of positional embeddings.
     """
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
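
For reference, a hedged sketch of what sinusoidal timestep embeddings of this kind compute; it follows the DDPM formulation the docstring cites rather than the function body, which this diff does not show, and it ignores the flip_sin_to_cos, downscale_freq_shift, and scale options:

import math
import torch

def sinusoidal_timestep_embedding(timesteps, embedding_dim, max_period=10000):
    # Half of the output dimensions carry sin terms, the other half cos terms,
    # at geometrically spaced frequencies between 1 and 1/max_period.
    half = embedding_dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # shape [N, embedding_dim]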
src/diffusers/models/resnet.py    View file @ 26ea58d4

@@ -58,9 +58,8 @@ class Upsample(nn.Module):
     """
     An upsampling layer with an optional convolution.

-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
         upsampling occurs in the inner-two dimensions.
     """

@@ -97,9 +96,8 @@ class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.

-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
         downsampling occurs in the inner-two dimensions.
     """

@@ -136,6 +134,86 @@ class Downsample(nn.Module):
         return self.op(x)


+class UNetUpsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class GlideUpsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+
+    :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
+    applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+        upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+
+class LDMUpsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param
+    use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D.
+    If 3D, then
+        upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+
+class GradTTSUpsample(torch.nn.Module):
+    def __init__(self, dim):
+        super(Upsample, self).__init__()
+        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # TODO (patil-suraj): needs test
 class Upsample1d(nn.Module):
     def __init__(self, dim):
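
A quick, hedged usage sketch for the UNetUpsample block added above (the input sizes are arbitrary example values, and the import path assumes the class is exposed from diffusers.models.resnet as in this diff): nearest-neighbour interpolation doubles the spatial resolution while the optional 3x3 convolution keeps the channel count unchanged.

import torch
from diffusers.models.resnet import UNetUpsample

up = UNetUpsample(in_channels=64, with_conv=True)
x = torch.randn(2, 64, 16, 16)
print(up(x).shape)  # torch.Size([2, 64, 32, 32])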
src/diffusers/models/unet_glide.py    View file @ 26ea58d4

@@ -82,8 +82,7 @@ def normalization(channels, swish=0.0):
     """
     Make a standard normalization layer, with an optional swish activation.

-    :param channels: number of input channels.
-    :return: an nn.Module for normalization.
+    :param channels: number of input channels. :return: an nn.Module for normalization.
     """
     return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)

@@ -111,8 +110,7 @@ class TimestepBlock(nn.Module):
 class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     """
-    A sequential module that passes timestep embeddings to the children that
-    support it as an extra input.
+    A sequential module that passes timestep embeddings to the children that support it as an extra input.
     """

     def forward(self, x, emb, encoder_out=None):

@@ -130,17 +128,13 @@ class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.

-    :param channels: the number of input channels.
-    :param emb_channels: the number of timestep embedding channels.
-    :param dropout: the rate of dropout.
-    :param out_channels: if specified, the number of out channels.
-    :param use_conv: if True and out_channels is specified, use a spatial
-        convolution instead of a smaller 1x1 convolution to change the
-        channels in the skip connection.
-    :param dims: determines if the signal is 1D, 2D, or 3D.
-    :param use_checkpoint: if True, use gradient checkpointing on this module.
-    :param up: if True, use this block for upsampling.
-    :param down: if True, use this block for downsampling.
+    :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param
+    use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the channels in the skip connection.
+    :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing
+    on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for
+    downsampling.
     """

     def __init__(

@@ -207,8 +201,7 @@ class ResBlock(TimestepBlock):
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.

-        :param x: an [N x C x ...] Tensor of features.
-        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
         if self.updown:

@@ -292,8 +285,8 @@ class QKVAttention(nn.Module):
         """
         Apply QKV attention.

         :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after
         attention.
         """
         bs, width, length = qkv.shape
         assert width % (3 * self.n_heads) == 0

@@ -315,29 +308,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
     """
     The full UNet model with attention and timestep embedding.

-    :param in_channels: channels in the input Tensor.
-    :param model_channels: base channel count for the model.
-    :param out_channels: channels in the output Tensor.
-    :param num_res_blocks: number of residual blocks per downsample.
+    :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param
+    out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample.
     :param attention_resolutions: a collection of downsample rates at which
-        attention will take place. May be a set, list, or tuple.
-        For example, if this contains 4, then at 4x downsampling, attention
-        will be used.
-    :param dropout: the dropout probability.
-    :param channel_mult: channel multiplier for each level of the UNet.
-    :param conv_resample: if True, use learned convolutions for upsampling and
+        attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
+        downsampling, attention will be used.
+    :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param
+    conv_resample: if True, use learned convolutions for upsampling and
         downsampling.
     :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this
     model will be
         class-conditional with `num_classes` classes.
-    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
-    :param num_heads: the number of attention heads in each attention layer.
-    :param num_heads_channels: if specified, ignore num_heads and instead use
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention
+    heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use
         a fixed channel width per attention head.
     :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
     :param resblock_updown: use residual blocks
        for up/downsampling.
     """

     def __init__(

@@ -545,10 +533,8 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
         """
         Apply the model to an input batch.

-        :param x: an [N x C x ...] Tensor of inputs.
-        :param timesteps: a 1-D batch of timesteps.
-        :param y: an [N] Tensor of labels, if class-conditional.
-        :return: an [N x C x ...] Tensor of outputs.
+        :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param y: an [N]
+        Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs.
         """
         hs = []
src/diffusers/models/unet_grad_tts.py    View file @ 26ea58d4

 import torch
 from numpy import pad

-try:
-    from einops import rearrange
-except:
-    print("Einops is not installed")
-    pass
-
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding

@@ -65,6 +58,7 @@ class LinearAttention(torch.nn.Module):
     def __init__(self, dim, heads=4, dim_head=32):
         super(LinearAttention, self).__init__()
         self.heads = heads
+        self.dim_head = dim_head
         hidden_dim = dim_head * heads
         self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
         self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

@@ -72,11 +66,17 @@ class LinearAttention(torch.nn.Module):
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+        q, k, v = (
+            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
+            .permute(1, 0, 2, 3, 4, 5)
+            .reshape(3, b, self.heads, self.dim_head, -1)
+        )
         k = k.softmax(dim=-1)
         context = torch.einsum("bhdn,bhen->bhde", k, v)
         out = torch.einsum("bhde,bhdn->bhen", context, q)
-        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
         return self.to_out(out)
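
The change above swaps the einops rearrange calls for plain reshape/permute so that einops becomes an optional dependency. A small, hedged check of the equivalence for the qkv split (written independently of the library; the sizes are arbitrary example values):

import torch
from einops import rearrange

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
qkv = torch.randn(b, 3 * heads * dim_head, h, w)

with_einops = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)
without_einops = (
    qkv.reshape(b, 3, heads, dim_head, h, w)
    .permute(1, 0, 2, 3, 4, 5)
    .reshape(3, b, heads, dim_head, -1)
)
print(torch.equal(with_einops, without_einops))  # True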
src/diffusers/models/unet_ldm.py
View file @
26ea58d4
...
@@ -7,13 +7,6 @@ import torch
...
@@ -7,13 +7,6 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
try
:
from
einops
import
rearrange
,
repeat
except
:
print
(
"Einops is not installed"
)
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
from
.embeddings
import
get_timestep_embedding
from
.embeddings
import
get_timestep_embedding
...
@@ -155,7 +148,23 @@ class CrossAttention(nn.Module):
...
@@ -155,7 +148,23 @@ class CrossAttention(nn.Module):
self
.
to_out
=
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
query_dim
),
nn
.
Dropout
(
dropout
))
self
.
to_out
=
nn
.
Sequential
(
nn
.
Linear
(
inner_dim
,
query_dim
),
nn
.
Dropout
(
dropout
))
def
reshape_heads_to_batch_dim
(
self
,
tensor
):
batch_size
,
seq_len
,
dim
=
tensor
.
shape
head_size
=
self
.
heads
tensor
=
tensor
.
reshape
(
batch_size
,
seq_len
,
head_size
,
dim
//
head_size
)
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
*
head_size
,
seq_len
,
dim
//
head_size
)
return
tensor
def
reshape_batch_dim_to_heads
(
self
,
tensor
):
batch_size
,
seq_len
,
dim
=
tensor
.
shape
head_size
=
self
.
heads
tensor
=
tensor
.
reshape
(
batch_size
//
head_size
,
head_size
,
seq_len
,
dim
)
tensor
=
tensor
.
permute
(
0
,
2
,
1
,
3
).
reshape
(
batch_size
//
head_size
,
seq_len
,
dim
*
head_size
)
return
tensor
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
):
def
forward
(
self
,
x
,
context
=
None
,
mask
=
None
):
batch_size
,
sequence_length
,
dim
=
x
.
shape
h
=
self
.
heads
h
=
self
.
heads
q
=
self
.
to_q
(
x
)
q
=
self
.
to_q
(
x
)
...
@@ -163,21 +172,29 @@ class CrossAttention(nn.Module):
...
@@ -163,21 +172,29 @@ class CrossAttention(nn.Module):
k
=
self
.
to_k
(
context
)
k
=
self
.
to_k
(
context
)
v
=
self
.
to_v
(
context
)
v
=
self
.
to_v
(
context
)
q
,
k
,
v
=
map
(
lambda
t
:
rearrange
(
t
,
"b n (h d) -> (b h) n d"
,
h
=
h
),
(
q
,
k
,
v
))
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
q
=
self
.
reshape_heads_to_batch_dim
(
q
)
k
=
self
.
reshape_heads_to_batch_dim
(
k
)
v
=
self
.
reshape_heads_to_batch_dim
(
v
)
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
sim
=
torch
.
einsum
(
"b i d, b j d -> b i j"
,
q
,
k
)
*
self
.
scale
if
exists
(
mask
):
if
exists
(
mask
):
mask
=
rearrange
(
mask
,
"b ... -> b (...)"
)
# mask = rearrange(mask, "b ... -> b (...)")
maks
=
mask
.
reshape
(
batch_size
,
-
1
)
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
max_neg_value
=
-
torch
.
finfo
(
sim
.
dtype
).
max
mask
=
repeat
(
mask
,
"b j -> (b h) () j"
,
h
=
h
)
# mask = repeat(mask, "b j -> (b h) () j", h=h)
mask
=
mask
[:,
None
,
:].
repeat
(
h
,
1
,
1
)
# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
sim
.
masked_fill_
(
~
mask
,
max_neg_value
)
# attention, what we cannot get enough of
# attention, what we cannot get enough of
attn
=
sim
.
softmax
(
dim
=-
1
)
attn
=
sim
.
softmax
(
dim
=-
1
)
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
torch
.
einsum
(
"b i j, b j d -> b i d"
,
attn
,
v
)
out
=
rearrange
(
out
,
"(b h) n d -> b n (h d)"
,
h
=
h
)
out
=
self
.
reshape_batch_dim_to_heads
(
out
)
# out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return
self
.
to_out
(
out
)
return
self
.
to_out
(
out
)
...
@@ -205,11 +222,8 @@ class BasicTransformerBlock(nn.Module):
...
@@ -205,11 +222,8 @@ class BasicTransformerBlock(nn.Module):
class
SpatialTransformer
(
nn
.
Module
):
class
SpatialTransformer
(
nn
.
Module
):
"""
"""
Transformer block for image-like data.
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
First, project the input (aka embedding)
standard transformer action. Finally, reshape to image
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
"""
def
__init__
(
self
,
in_channels
,
n_heads
,
d_head
,
depth
=
1
,
dropout
=
0.0
,
context_dim
=
None
):
def
__init__
(
self
,
in_channels
,
n_heads
,
d_head
,
depth
=
1
,
dropout
=
0.0
,
context_dim
=
None
):
...
@@ -235,10 +249,10 @@ class SpatialTransformer(nn.Module):
...
@@ -235,10 +249,10 @@ class SpatialTransformer(nn.Module):
x_in
=
x
x_in
=
x
x
=
self
.
norm
(
x
)
x
=
self
.
norm
(
x
)
x
=
self
.
proj_in
(
x
)
x
=
self
.
proj_in
(
x
)
x
=
rearrang
e
(
x
,
"b c h w -> b (h w)
c
"
)
x
=
x
.
permut
e
(
0
,
2
,
3
,
1
).
reshape
(
b
,
h
*
w
,
c
)
for
block
in
self
.
transformer_blocks
:
for
block
in
self
.
transformer_blocks
:
x
=
block
(
x
,
context
=
context
)
x
=
block
(
x
,
context
=
context
)
x
=
rearrange
(
x
,
"b (h
w
)
c
-> b c h w"
,
h
=
h
,
w
=
w
)
x
=
x
.
reshape
(
b
,
h
,
w
,
c
).
permute
(
0
,
3
,
1
,
2
)
x
=
self
.
proj_out
(
x
)
x
=
self
.
proj_out
(
x
)
return
x
+
x_in
return
x
+
x_in
...
@@ -314,8 +328,7 @@ def normalization(channels, swish=0.0):
...
@@ -314,8 +328,7 @@ def normalization(channels, swish=0.0):
"""
"""
Make a standard normalization layer, with an optional swish activation.
Make a standard normalization layer, with an optional swish activation.
:param channels: number of input channels.
:param channels: number of input channels. :return: an nn.Module for normalization.
:return: an nn.Module for normalization.
"""
"""
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
return
GroupNorm32
(
num_channels
=
channels
,
num_groups
=
32
,
swish
=
swish
)
...
@@ -365,8 +378,7 @@ class TimestepBlock(nn.Module):

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that support it as an extra input.
    """

    def forward(self, x, emb, context=None):
...
@@ -382,18 +394,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1
        convolution to change the channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
...
@@ -481,8 +489,8 @@ class ResBlock(TimestepBlock):

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    but adapted to the N-d case.
    """
...
@@ -531,9 +539,8 @@ class QKVAttention(nn.Module):

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
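
For orientation, a minimal sketch of how an [N x (3 * H * C) x T] qkv tensor is typically consumed; the split order, the 1/sqrt(C) scaling and the variable names are assumptions drawn from the docstring, not this class's exact code:

import torch

n, heads, ch, t = 2, 4, 8, 16
qkv = torch.randn(n, 3 * heads * ch, t)

q, k, v = qkv.chunk(3, dim=1)                              # each [N x (H * C) x T]
q = q.reshape(n * heads, ch, t)                            # fold heads into the batch dim
k = k.reshape(n * heads, ch, t)
v = v.reshape(n * heads, ch, t)

weight = torch.einsum("bct,bcs->bts", q * ch**-0.5, k)     # [N*H x T x T] attention logits
attn = weight.softmax(dim=-1)
out = torch.einsum("bts,bcs->bct", attn, v).reshape(n, heads * ch, t)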
...
@@ -556,13 +563,9 @@ class QKVAttention(nn.Module):

def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an attention operation.

    Meant to be used like:

        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
...
@@ -585,9 +588,8 @@ class QKVAttentionLegacy(nn.Module):

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
...
@@ -606,31 +608,25 @@ class QKVAttentionLegacy(nn.Module):

class UNetLDMModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set,
        list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially increased efficiency.
    """
...
@@ -933,12 +929,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
...
@@ -970,8 +963,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):

class EncoderUNetModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding. For usage, see UNet.
    """

    def __init__(
...
@@ -1157,10 +1149,8 @@ class EncoderUNetModel(nn.Module):

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )
...
src/diffusers/models/unet_rl.py
...
@@ -5,18 +5,18 @@ import math

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


# try:
#     import einops
#     from einops.layers.torch import Rearrange
# except:
#     print("Einops is not installed")
#     pass


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
...
@@ -50,6 +50,21 @@ class Upsample1d(nn.Module):
        return self.conv(x)
class RearrangeDim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
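
A small sanity check of the class above (assuming it is in scope); it mimics the two `einops.layers.torch.Rearrange` patterns used further down by inserting or dropping a singleton axis:

import torch

rearrange_dim = RearrangeDim()

temb = torch.randn(4, 32)                        # 2D: append a trailing singleton -> (4, 32, 1)
assert rearrange_dim(temb).shape == (4, 32, 1)

feat = torch.randn(4, 32, 64)                    # 3D: insert a singleton -> (4, 32, 1, 64)
assert rearrange_dim(feat).shape == (4, 32, 1, 64)
assert rearrange_dim(rearrange_dim(feat)).shape == (4, 32, 64)   # 4D: squeeze it back out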
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
...
@@ -60,9 +75,11 @@ class Conv1dBlock(nn.Module):
        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )
...
@@ -84,7 +101,8 @@ class ResidualTemporalBlock(nn.Module):
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            RearrangeDim(),
            # Rearrange("batch t -> batch t 1"),
        )
        self.residual_conv = (
...
@@ -93,10 +111,8 @@ class ResidualTemporalBlock(nn.Module):

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]

        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
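
Why the trailing singleton from RearrangeDim matters here (toy shapes, purely illustrative): the time embedding [batch x out_channels x 1] broadcasts across the horizon axis of the convolutional features.

import torch

feat = torch.randn(2, 32, 64)          # [batch x out_channels x horizon], e.g. the Conv1d output
temb = torch.randn(2, 32)              # [batch x out_channels], e.g. the Linear output in time_mlp
out = feat + temb[:, :, None]          # broadcast add over the horizon axis
assert out.shape == (2, 32, 64)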
...
@@ -184,7 +200,8 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
        x : [ batch x horizon x transition ]
        """

        # x = einops.rearrange(x, "b h t -> b t h")
        x = x.permute(0, 2, 1)

        t = self.time_mlp(time)
        h = []
...
@@ -206,7 +223,8 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
        x = self.final_conv(x)

        # x = einops.rearrange(x, "b t h -> b h t")
        x = x.permute(0, 2, 1)

        return x
...
@@ -263,7 +281,8 @@ class TemporalValue(nn.Module):
        x : [ batch x horizon x transition ]
        """

        # x = einops.rearrange(x, "b h t -> b t h")
        x = x.permute(0, 2, 1)

        t = self.time_mlp(time)
...
src/diffusers/models/unet_sde_score_estimation.py
...
@@ -136,26 +136,21 @@ def naive_downsample_2d(x, factor=2):

def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
    """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.

    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
    order.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
            performed by `inChannels = x.shape[0] // numGroups`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
        `x`.
    """
    assert isinstance(factor, int) and factor >= 1
...
@@ -208,25 +203,21 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):

def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
    """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.

    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
    order.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
            performed by `inChannels = x.shape[0] // numGroups`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
        datatype as `x`.
    """
    assert isinstance(factor, int) and factor >= 1
...
@@ -258,22 +249,16 @@ def _shape(x, dim):

def upsample_2d(x, k=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a multiple of the upsampling factor.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H * factor, W * factor]`
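
As a hedged illustration of the `[1] * factor` default (not the library routine itself): zero-stuff the image by the factor and run the box FIR filter; the result is exactly nearest-neighbor upsampling.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4)
factor = 2

up = torch.zeros(1, 1, 4 * factor, 4 * factor)
up[:, :, ::factor, ::factor] = x                           # zero-stuffing
k = torch.ones(1, 1, factor, factor)                       # [1] * factor box filter (gain-scaled taps)
out = F.conv2d(F.pad(up, (factor - 1, 0, factor - 1, 0)), k)

expected = x.repeat_interleave(factor, dim=2).repeat_interleave(factor, dim=3)
assert torch.equal(out, expected)                          # nearest-neighbor upsample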
...
@@ -289,22 +274,16 @@ def upsample_2d(x, k=None, factor=2, gain=1):

def downsample_2d(x, k=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given filter.

    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        Tensor of the shape `[N, C, H // factor, W // factor]`
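
And the downsampling counterpart (again a standalone sketch, not the library code): the normalized `[1] * factor` filter applied depthwise with stride `factor` is plain average pooling.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
factor = 2

kernel = (torch.ones(factor, factor) / factor**2).expand(3, 1, factor, factor).contiguous()
down = F.conv2d(x, kernel, stride=factor, groups=3)        # depthwise box filter
assert torch.allclose(down, F.avg_pool2d(x, factor), atol=1e-6)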
...
src/diffusers/pipelines/grad_tts_utils.py
...
@@ -290,7 +290,7 @@ def normalize_numbers(text):
    return text


""" from https://github.com/keithito/tacotron"""
_pad = "_"
...
@@ -322,8 +322,8 @@ def get_arpabet(word, dictionary):

def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on
    {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
...
src/diffusers/pipelines/pipeline_bddm.py
...
@@ -29,8 +29,7 @@ from ..pipeline_utils import DiffusionPipeline

def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed a diffusion step $t$ into a higher dimensional space. E.g. the embedding vector in the 128-dimensional space
    is [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]

    Parameters:
...
@@ -53,8 +52,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):

"""
Below scripts were borrowed from https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
"""
...
src/diffusers/pipelines/pipeline_glide.py
...
@@ -699,9 +699,8 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
...
src/diffusers/pipelines/pipeline_grad_tts.py
""" from https://github.com/jaywalnut310/glow-tts"""
import math
...
src/diffusers/pipelines/pipeline_latent_diffusion.py
...
@@ -554,11 +554,9 @@ class LDMBertModel(LDMBertPreTrainedModel):

def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal embeddings (from Fairseq). This matches the implementation in Denoising Diffusion Probabilistic
    Models and in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You
    Need".
    """
    assert len(timesteps.shape) == 1
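
For reference, a minimal sketch of a sinusoidal embedding of this flavour (a re-derivation under the usual log(10000) convention; the exact constants in the file may differ):

import math
import torch

def sinusoidal_embedding_sketch(timesteps: torch.Tensor, embedding_dim: int) -> torch.Tensor:
    half_dim = embedding_dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - 1))
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))   # zero-pad odd dimensions
    return emb

sinusoidal_embedding_sketch(torch.arange(4), 32).shape   # -> torch.Size([4, 32])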
...
@@ -1055,8 +1053,8 @@ class Decoder(nn.Module):

class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
...
src/diffusers/schedulers/scheduling_ddim.py
...
@@ -25,13 +25,12 @@ from .scheduling_utils import SchedulerMixin

def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of (1-beta) up
        to that part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
    """
...
src/diffusers/schedulers/scheduling_ddpm.py
...
@@ -25,13 +25,12 @@ from .scheduling_utils import SchedulerMixin

def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of (1-beta) up
        to that part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
    """
...
@@ -144,16 +143,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample

     def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
-        if timesteps.dim() != 1:
-            raise ValueError("`timesteps` must be a 1D tensor")
-        device = original_samples.device
-        batch_size = original_samples.shape[0]
-        timesteps = timesteps.reshape(batch_size, 1, 1, 1)
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
+        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples

     def __len__(self):
...
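
What training_step computes, as a standalone hedged sketch of the DDPM forward-noising rule x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps (the linear beta schedule below is a toy stand-in for the scheduler's own):

import torch

betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(4, 3, 32, 32)             # clean training images
eps = torch.randn_like(x0)                 # Gaussian noise
t = torch.randint(0, 1000, (4,))           # one timestep per sample

sqrt_ab = alphas_cumprod[t].sqrt().reshape(-1, 1, 1, 1)                  # match_shape-style broadcast
sqrt_one_minus_ab = (1 - alphas_cumprod[t]).sqrt().reshape(-1, 1, 1, 1)
x_t = sqrt_ab * x0 + sqrt_one_minus_ab * eps                             # the model's noisy input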