OpenDAS / diffusers · Commits

Commit 4e08e0ca — merge, authored Jun 27, 2022 by Patrick von Platen.
Parents: af6c1439, 07ff0abf

Showing 12 changed files with 351 additions and 136 deletions.
Files changed:

  examples/experimental/train_glide_text_to_image.py   +201    -0
  examples/train_latent_text_to_image.py                 +45   -31
  examples/train_unconditional.py                         +6    -3
  src/diffusers/models/resnet.py                          +3    -2
  src/diffusers/models/unet.py                            +2   -15
  src/diffusers/models/unet_glide.py                      +4   -33
  src/diffusers/models/unet_grad_tts.py                   +2   -10
  src/diffusers/models/unet_ldm.py                        +4   -32
  src/diffusers/schedulers/scheduling_ddpm.py             +4    -8
  src/diffusers/schedulers/scheduling_utils.py           +28    -0
  tests/test_layers_utils.py                             +51    -0
  tests/test_modeling_utils.py                            +1    -2
examples/experimental/train_glide_text_to_image.py  (new file, mode 100644)
import argparse
import os

import torch
import torch.nn.functional as F

import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, Glide, GlideUNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = logging.get_logger(__name__)


def main(args):
    accelerator = Accelerator(mixed_precision=args.mixed_precision)

    pipeline = Glide.from_pretrained("fusing/glide-base")
    model = pipeline.text_unet
    noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    text_encoder = pipeline.text_encoder.eval()

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        text_inputs = pipeline.tokenizer(
            examples["caption"], padding="max_length", max_length=77, return_tensors="pt"
        )
        text_inputs = text_inputs.input_ids.to(accelerator.device)
        with torch.no_grad():
            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state
        return {"images": images, "text_embeddings": text_embeddings}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, text_encoder, optimizer, train_dataloader, lr_scheduler
    )

    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

    # Train!
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataloader.dataset)}")
    logger.info(f" Num Epochs = {args.num_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps}")

    for epoch in range(args.num_epochs):
        model.train()
        with tqdm(total=len(train_dataloader), unit="ba") as pbar:
            pbar.set_description(f"Epoch {epoch}")
            for step, batch in enumerate(train_dataloader):
                clean_images = batch["images"]
                batch_size, n_channels, height, width = clean_images.shape
                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
                timesteps = torch.randint(
                    0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device
                ).long()

                # add noise onto the clean images according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)

                if step % args.gradient_accumulation_steps != 0:
                    with accelerator.no_sync(model):
                        model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                        model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                        # Learn the variance using the variational bound, but don't let
                        # it affect our mean prediction. (NOTE: frozen_out is currently unused.)
                        frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                        # predict the noise residual
                        loss = F.mse_loss(model_output, noise_samples)
                        loss = loss / args.gradient_accumulation_steps
                        accelerator.backward(loss)
                        optimizer.step()
                else:
                    model_output = model(noisy_images, timesteps, batch["text_embeddings"])
                    model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
                    # Learn the variance using the variational bound, but don't let
                    # it affect our mean prediction. (NOTE: frozen_out is currently unused.)
                    frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
                    # predict the noise residual
                    loss = F.mse_loss(model_output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                pbar.update(1)
                pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])

        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            model.eval()
            with torch.no_grad():
                pipeline.unet = accelerator.unwrap_model(model)
                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50)

                # process image to PIL
                image_processed = image.squeeze(0)
                image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
                image_pil = PIL.Image.fromarray(image_processed)

                # save image
                test_dir = os.path.join(args.output_dir, "test_samples")
                os.makedirs(test_dir, exist_ok=True)
                image_pil.save(f"{test_dir}/{epoch:04d}.png")

            # save the model
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
    parser.add_argument("--output_dir", type=str, default="glide-text2image")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )
    args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)
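
For context on the training loop above: noise_scheduler.training_step applies the standard DDPM forward-noising rule to produce noisy_images. The following is a minimal, self-contained sketch of what that call computes, using a hypothetical linear-beta schedule for illustration only (the real schedule comes from the DDPMScheduler config; the refactored implementation appears in the scheduling_ddpm.py change below):

import torch

# Hypothetical linear-beta schedule, for illustration only.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(clean, noise, timesteps):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_one_minus = (1 - alphas_cumprod[timesteps]) ** 0.5
    # Unsqueeze the per-sample scalars so they broadcast over (C, H, W).
    while sqrt_alpha_prod.dim() < clean.dim():
        sqrt_alpha_prod = sqrt_alpha_prod[..., None]
        sqrt_one_minus = sqrt_one_minus[..., None]
    return sqrt_alpha_prod * clean + sqrt_one_minus * noise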
examples/train_latent_text_to_image.py
@@ -4,19 +4,19 @@ import os
 import torch
 import torch.nn.functional as F
 
+import bitsandbytes as bnb
 import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
-from diffusers import DDPM, DDPMScheduler, UNetLDMModel
+from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
-from diffusers.modeling_utils import unwrap_model
 from diffusers.optimization import get_scheduler
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
     Compose,
     InterpolationMode,
-    Lambda,
+    Normalize,
     RandomHorizontalFlip,
     Resize,
     ToTensor,
@@ -30,6 +30,8 @@ logger = logging.get_logger(__name__)
 def main(args):
     accelerator = Accelerator(mixed_precision=args.mixed_precision)
 
+    pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large")
+    pipeline.unet = None  # this model will be trained from scratch now
     model = UNetLDMModel(
         attention_resolutions=[4, 2, 1],
         channel_mult=[1, 2, 4, 4],
@@ -37,7 +39,7 @@ def main(args):
         conv_resample=True,
         dims=2,
         dropout=0,
-        image_size=32,
+        image_size=8,
         in_channels=4,
         model_channels=320,
         num_heads=8,
@@ -51,7 +53,7 @@ def main(args):
         legacy=False,
     )
     noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)
 
     augmentations = Compose(
         [
@@ -59,14 +61,22 @@ def main(args):
             CenterCrop(args.resolution),
             RandomHorizontalFlip(),
             ToTensor(),
-            Lambda(lambda x: x * 2 - 1),
+            Normalize([0.5], [0.5]),
         ]
     )
     dataset = load_dataset(args.dataset, split="train")
 
+    text_encoder = pipeline.bert.eval()
+    vqvae = pipeline.vqvae.eval()
+
     def transforms(examples):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
-        return {"input": images}
+        text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
+        with torch.no_grad():
+            text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state
+            images = 1 / 0.18215 * torch.stack(images, dim=0)
+            latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode()
+        return {"images": images, "text_embeddings": text_embeddings, "latents": latents}
 
     dataset.set_transform(transforms)
     train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
@@ -78,9 +88,11 @@ def main(args):
         num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
     )
 
-    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        model, optimizer, train_dataloader, lr_scheduler
+    model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler
     )
+    text_encoder = text_encoder.cpu()
+    vqvae = vqvae.cpu()
 
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
@@ -98,29 +110,31 @@ def main(args):
     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f" Total optimization steps = {max_steps}")
 
+    global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
         with tqdm(total=len(train_dataloader), unit="ba") as pbar:
             pbar.set_description(f"Epoch {epoch}")
             for step, batch in enumerate(train_dataloader):
-                clean_images = batch["input"]
-                noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
-                bsz = clean_images.shape[0]
-                timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
+                clean_latents = batch["latents"]
+                noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device)
+                bsz = clean_latents.shape[0]
+                timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long()
 
-                # add noise onto the clean images according to the noise magnitude at each timestep
+                # add noise onto the clean latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
+                noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps)
 
                 if step % args.gradient_accumulation_steps != 0:
                     with accelerator.no_sync(model):
-                        output = model(noisy_images, timesteps)
+                        output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
                         # predict the noise residual
                         loss = F.mse_loss(output, noise_samples)
                         loss = loss / args.gradient_accumulation_steps
                         accelerator.backward(loss)
+                        optimizer.step()
                 else:
-                    output = model(noisy_images, timesteps)
+                    output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
                     # predict the noise residual
                     loss = F.mse_loss(output, noise_samples)
                     loss = loss / args.gradient_accumulation_steps
@@ -131,24 +145,25 @@ def main(args):
                 optimizer.zero_grad()
                 pbar.update(1)
                 pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
+                global_step += 1
 
-        optimizer.step()
-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()
 
         # Generate a sample image for visual inspection
-        if args.local_rank in [-1, 0]:
+        if accelerator.is_main_process:
             model.eval()
             with torch.no_grad():
-                pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
+                pipeline.unet = accelerator.unwrap_model(model)
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                image = pipeline(generator=generator)
+                image = pipeline(["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)
                 # process image to PIL
                 image_processed = image.cpu().permute(0, 2, 3, 1)
-                image_processed = (image_processed + 1.0) * 127.5
+                image_processed = image_processed * 255.0
                 image_processed = image_processed.type(torch.uint8).numpy()
                 image_pil = PIL.Image.fromarray(image_processed[0])
@@ -162,20 +177,19 @@ def main(args):
                 push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
             else:
                 pipeline.save_pretrained(args.output_dir)
-        if is_distributed:
-            torch.distributed.barrier()
+        accelerator.wait_for_everyone()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model")
+    parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
+    parser.add_argument("--output_dir", type=str, default="ldm-text2image")
     parser.add_argument("--overwrite_output_dir", action="store_true")
-    parser.add_argument("--resolution", type=int, default=64)
-    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--resolution", type=int, default=128)
+    parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
     parser.add_argument("--push_to_hub", action="store_true")
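
The step % gradient_accumulation_steps branches in both training scripts implement gradient accumulation: on non-sync steps, accelerator.no_sync(model) suppresses the DDP gradient all-reduce so gradients simply accumulate locally, and the all-reduce plus optimizer update happen only on sync steps. A minimal sketch of the same pattern with a toy model (hypothetical names, not this script's code):

import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

accum = 4
for step in range(100):
    x = torch.randn(16, 8, device=accelerator.device)
    if step % accum != 0:
        # Accumulation step: skip the gradient all-reduce.
        with accelerator.no_sync(model):
            loss = F.mse_loss(model(x), torch.zeros_like(x)) / accum
            accelerator.backward(loss)
    else:
        # Sync step: gradients all-reduce here, then update and clear them.
        loss = F.mse_loss(model(x), torch.zeros_like(x)) / accum
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()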
examples/train_unconditional.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
-from diffusers import DDPM, DDPMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -71,7 +71,7 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
-    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
@@ -133,7 +133,7 @@ def main(args):
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDPM(
+                pipeline = DDPMPipeline(
                     unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
                 )
@@ -172,6 +172,9 @@ if __name__ == "__main__":
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
+    parser.add_argument("--ema_power", type=float, default=3 / 4)
+    parser.add_argument("--ema_max_decay", type=float, default=0.999)
    parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument("--hub_token", type=str, default=None)
     parser.add_argument("--hub_model_id", type=str, default=None)
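
The new --ema_inv_gamma/--ema_power/--ema_max_decay flags expose the EMA warm-up as arguments. Assuming EMAModel follows the usual Karras-style warm-up schedule (an assumption about diffusers.training_utils, which this diff does not show), the per-step decay these defaults imply would look like:

def ema_decay(step, inv_gamma=1.0, power=3 / 4, max_value=0.999):
    # Decay ramps up from 0 toward max_value as training progresses.
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(0.0, min(value, max_value))

# Illustrative values under these defaults:
#   step 1     -> ~0.41
#   step 100   -> ~0.97
#   step 10000 -> 0.999 (capped by max_value)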
src/diffusers/models/resnet.py
@@ -64,7 +64,7 @@ class Upsample(nn.Module):
         upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
+    def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -73,7 +73,7 @@ class Upsample(nn.Module):
         self.use_conv_transpose = use_conv_transpose
 
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
+            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
             self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
@@ -207,6 +207,7 @@ class GradTTSUpsample(torch.nn.Module):
         return self.conv(x)
 
+# TODO (patil-suraj): needs test
 class Upsample1d(nn.Module):
     def __init__(self, dim):
         super().__init__()
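
Two things happen here: use_conv gains a default so callers can request plain interpolation, and the transposed-convolution branch now uses self.out_channels instead of out_channels, which matters because out_channels may be None (self.out_channels falls back to channels). With this, the shared Upsample covers all three upsampling variants the individual UNets previously implemented locally. A short sketch of the three modes; the expected output shapes match the new tests in tests/test_layers_utils.py below:

import torch
from diffusers.models.resnet import Upsample

sample = torch.randn(1, 32, 32, 32)

# 1. Nearest-neighbor interpolation only (the new use_conv=False default).
print(Upsample(channels=32)(sample).shape)  # (1, 32, 64, 64)

# 2. Interpolation followed by a 3x3 conv, optionally changing channels.
print(Upsample(channels=32, use_conv=True, out_channels=64)(sample).shape)  # (1, 64, 64, 64)

# 3. A learned 4x4, stride-2 transposed convolution.
print(Upsample(channels=32, use_conv_transpose=True)(sample).shape)  # (1, 32, 64, 64)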
src/diffusers/models/unet.py
@@ -31,6 +31,7 @@ from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import Upsample
 
 def nonlinearity(x):
@@ -42,20 +43,6 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
-class Upsample(nn.Module):
-    def __init__(self, in_channels, with_conv):
-        super().__init__()
-        self.with_conv = with_conv
-        if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
-    def forward(self, x):
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
-        if self.with_conv:
-            x = self.conv(x)
-        return x
-
 class Downsample(nn.Module):
     def __init__(self, in_channels, with_conv):
         super().__init__()
@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin):
                 up.block = block
                 up.attn = attn
                 if i_level != 0:
-                    up.upsample = Upsample(block_in, resamp_with_conv)
+                    up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
                     curr_res = curr_res * 2
                 self.up.insert(0, up)  # prepend to get consistent order
src/diffusers/models/unet_glide.py
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import Upsample
 
 def convert_module_to_f16(l):
@@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x
 
-class Upsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
 class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.
@@ -231,8 +202,8 @@ class ResBlock(TimestepBlock):
         self.updown = up or down
 
         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
+            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
             self.h_upd = Downsample(channels, False, dims)
             self.x_upd = Downsample(channels, False, dims)
@@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                             up=True,
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
src/diffusers/models/unet_grad_tts.py
@@ -3,6 +3,7 @@ import torch
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import Upsample
 
 class Mish(torch.nn.Module):
@@ -10,15 +11,6 @@ class Mish(torch.nn.Module):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
-class Upsample(torch.nn.Module):
-    def __init__(self, dim):
-        super(Upsample, self).__init__()
-        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
 class Downsample(torch.nn.Module):
     def __init__(self, dim):
         super(Downsample, self).__init__()
@@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
                     ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                     ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                     Residual(Rezero(LinearAttention(dim_in))),
-                    Upsample(dim_in),
+                    Upsample(dim_in, use_conv_transpose=True),
                 ]
             )
         )
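
Note that Upsample(dim_in, use_conv_transpose=True) builds conv_transpose_nd(dims, dim_in, dim_in, 4, 2, 1) with the default dims=2 and out_channels=None, which is exactly the torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) the deleted local Upsample used, so the Grad-TTS network should be architecturally unchanged by this consolidation.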
src/diffusers/models/unet_ldm.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
+from .resnet import Upsample
 
 # try:
@@ -403,35 +404,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
         return x
 
-class Upsample(nn.Module):
-    """
-    An upsampling layer with an optional convolution.
-
-    :param channels: channels in the inputs and outputs.
-    :param use_conv: a bool determining if a convolution is applied.
-    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
-        upsampling occurs in the inner-two dimensions.
-    """
-
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
-        super().__init__()
-        self.channels = channels
-        self.out_channels = out_channels or channels
-        self.use_conv = use_conv
-        self.dims = dims
-        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
-
-    def forward(self, x):
-        assert x.shape[1] == self.channels
-        if self.dims == 3:
-            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
-        else:
-            x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if self.use_conv:
-            x = self.conv(x)
-        return x
-
 class Downsample(nn.Module):
     """
     A downsampling layer with an optional convolution.
@@ -506,8 +478,8 @@ class ResBlock(TimestepBlock):
         self.updown = up or down
 
         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, use_conv=False, dims=dims)
+            self.x_upd = Upsample(channels, use_conv=False, dims=dims)
         elif down:
             self.h_upd = Downsample(channels, False, dims)
             self.x_upd = Downsample(channels, False, dims)
@@ -974,7 +946,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                             up=True,
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
src/diffusers/schedulers/scheduling_ddpm.py
@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample
 
     def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
-        if timesteps.dim() != 1:
-            raise ValueError("`timesteps` must be a 1D tensor")
-
-        device = original_samples.device
-        batch_size = original_samples.shape[0]
-        timesteps = timesteps.reshape(batch_size, 1, 1, 1)
-
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
+        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
     def __len__(self):
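
As an aside (standard DDPM notation, not part of the diff), training_step samples from the forward process $q(x_t \mid x_0)$:

$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$, where $\bar{\alpha}_t = \prod_{s \le t} (1 - \beta_s)$.

The two match_shape calls only reshape the per-timestep scalars so they broadcast against the sample batch; the math is unchanged from the deleted version, which did the reshape and device placement by hand.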
src/diffusers/schedulers/scheduling_utils.py
@@ -14,6 +14,8 @@
 import numpy as np
 import torch
 
+from typing import Union
+
 
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -50,3 +52,29 @@ class SchedulerMixin:
             return torch.log(tensor)
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
+        """
+        Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
+
+        Args:
+            timesteps: an array or tensor of values to extract.
+            broadcast_array: an array with a larger shape of K dimensions with the batch
+                dimension equal to the length of timesteps.
+
+        Returns:
+            a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+        """
+
+        tensor_format = getattr(self, "tensor_format", "pt")
+        values = values.flatten()
+
+        while len(values.shape) < len(broadcast_array.shape):
+            values = values[..., None]
+        if tensor_format == "pt":
+            values = values.to(broadcast_array.device)
+
+        return values
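
In effect, match_shape turns a length-batch_size vector into a (batch_size, 1, ..., 1) array that broadcasts against a batch of images or latents. A small standalone illustration (reimplementing the loop rather than importing the mixin):

import torch

values = torch.tensor([0.9, 0.5])          # one scalar per batch element
broadcast_array = torch.randn(2, 4, 8, 8)  # (batch, channels, height, width)

values = values.flatten()
while len(values.shape) < len(broadcast_array.shape):
    values = values[..., None]

print(values.shape)                      # torch.Size([2, 1, 1, 1])
print((values * broadcast_array).shape)  # torch.Size([2, 4, 8, 8])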
tests/test_layers_utils.py
@@ -22,6 +22,7 @@ import numpy as np
 import torch
 
 from diffusers.models.embeddings import get_timestep_embedding
+from diffusers.models.resnet import Upsample
 from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
                 torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
                 1e-3,
             )
         )
+
+
+class UpsampleBlockTests(unittest.TestCase):
+    def test_upsample_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=False)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_conv(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=True)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_conv_out_dim(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=True, out_channels=64)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 64, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_transpose(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
tests/test_modeling_utils.py
@@ -21,7 +21,7 @@ import unittest
 import numpy as np
 import torch
 
 from diffusers import (
+    # GradTTSPipeline,
     BDDMPipeline,
     DDIMPipeline,
     DDIMScheduler,
@@ -30,7 +30,6 @@ from diffusers import (
     GlidePipeline,
     GlideSuperResUNetModel,
     GlideTextToImageUNetModel,
-    GradTTSPipeline,
     GradTTSScheduler,
     LatentDiffusionPipeline,
     NCSNpp,