renzhc / diffusers_dcu · Commits · 8b975882

Commit 8b975882, authored Jun 15, 2022 by Patrick von Platen

Merge branch 'main' of https://github.com/huggingface/diffusers

Parents: 32c55673, 850d4345
Showing 8 changed files, with 675 additions and 42 deletions (+675 −42):
- README.md (+3 −2)
- examples/train_ddpm.py (+36 −25)
- src/diffusers/__init__.py (+1 −0)
- src/diffusers/models/__init__.py (+1 −0)
- src/diffusers/models/unet_grad_tts.py (+233 −0)
- src/diffusers/pipelines/pipeline_grad_tts.py (+385 −0)
- src/diffusers/pipelines/pipeline_latent_diffusion.py (+1 −1)
- tests/test_modeling_utils.py (+15 −14)
README.md

````diff
@@ -200,13 +200,14 @@ image_pil.save("test.png")
 #### **Text to Image generation with Latent Diffusion**
 
 _Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
 
 ```python
 from diffusers import DiffusionPipeline
 
 ldm = DiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")
 
-generator = torch.Generator()
-generator = generator.manual_seed(6694729458485568)
+generator = torch.manual_seed(42)
 
 prompt = "A painting of a squirrel eating a burger"
 image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)
````
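The seeding change works because `torch.manual_seed` both seeds the default generator and returns it, so the two-line `Generator` construction collapses into one call. A minimal sketch of the equivalence:

```python
import torch

generator = torch.manual_seed(42)  # seeds the default generator and returns it
a = torch.randn(4, generator=generator)

torch.manual_seed(42)  # reseeding reproduces the same draw without passing a generator
b = torch.randn(4)
assert torch.equal(a, b)
```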
examples/training_ddpm.py → examples/train_ddpm.py (renamed)

```diff
 import argparse
 import os
 
 import torch
-import PIL.Image
-import argparse
 import torch.nn.functional as F
+
+import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
```
```diff
@@ -31,44 +31,40 @@ def main(args):
         dropout=0.0,
         num_res_blocks=2,
         resamp_with_conv=True,
-        resolution=64,
+        resolution=args.resolution,
     )
     noise_scheduler = DDPMScheduler(timesteps=1000)
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
-
-    num_epochs = 100
-    batch_size = 16
-    gradient_accumulation_steps = 1
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
 
     augmentations = Compose(
         [
-            Resize(64, interpolation=InterpolationMode.BILINEAR),
-            RandomCrop(64),
+            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
+            RandomCrop(args.resolution),
             RandomHorizontalFlip(),
             ToTensor(),
             Lambda(lambda x: x * 2 - 1),
         ]
     )
-    dataset = load_dataset("huggan/pokemon", split="train")
+    dataset = load_dataset(args.dataset, split="train")
 
     def transforms(examples):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
         return {"input": images}
 
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
 
     lr_scheduler = get_linear_schedule_with_warmup(
         optimizer=optimizer,
-        num_warmup_steps=500,
-        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
+        num_warmup_steps=args.warmup_steps,
+        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
     )
 
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, lr_scheduler
     )
 
-    for epoch in range(num_epochs):
+    for epoch in range(args.num_epochs):
         model.train()
         with tqdm(total=len(train_dataloader), unit="ba") as pbar:
             pbar.set_description(f"Epoch {epoch}")
```
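One detail worth calling out in the augmentation stack: the trailing `Lambda` rescales `ToTensor`'s `[0, 1]` output to the `[-1, 1]` range the DDPM training loop assumes. A minimal sketch:

```python
import numpy as np
import PIL.Image
import torch
from torchvision.transforms import Compose, Lambda, ToTensor

rescale = Compose([ToTensor(), Lambda(lambda x: x * 2 - 1)])
img = PIL.Image.fromarray(np.full((64, 64, 3), 255, dtype=np.uint8))  # all-white image
t = rescale(img)  # uint8 255 -> 1.0 after ToTensor -> 1.0 * 2 - 1 = 1.0
assert torch.all(t == 1.0)
```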
```diff
@@ -84,14 +80,15 @@ def main(args):
                 noise_samples[idx] = noise
                 noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
 
-            if step % gradient_accumulation_steps != 0:
+            if step % args.gradient_accumulation_steps != 0:
                 with accelerator.no_sync(model):
                     output = model(noisy_images, timesteps)
-                    # predict the noise
+                    # predict the noise residual
                     loss = F.mse_loss(output, noise_samples)
                     accelerator.backward(loss)
             else:
                 output = model(noisy_images, timesteps)
                 # predict the noise residual
                 loss = F.mse_loss(output, noise_samples)
                 accelerator.backward(loss)
                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```
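The branch above is the standard gradient-accumulation pattern with `accelerate`: on non-sync steps, `accelerator.no_sync(model)` suppresses the DDP all-reduce so gradients accumulate locally, and only every `gradient_accumulation_steps`-th batch pays for synchronization and clipping. A condensed sketch of the same control flow (`accelerator`, `model`, and `loss` stand in for the script's objects):

```python
import torch

def backward_maybe_sync(step, n_accum, accelerator, model, loss):
    # Skip gradient synchronization on all but every n_accum-th batch.
    if step % n_accum != 0:
        with accelerator.no_sync(model):  # no DDP all-reduce on this backward
            accelerator.backward(loss)
    else:
        accelerator.backward(loss)  # gradients synchronize here
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```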
```diff
@@ -103,13 +100,18 @@ def main(args):
             optimizer.step()
 
         # Generate a sample image for visual inspection
         torch.distributed.barrier()
         if args.local_rank in [-1, 0]:
             model.eval()
             with torch.no_grad():
-                pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
-                generator = torch.Generator()
-                generator = generator.manual_seed(0)
+                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+                    pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
+                else:
+                    pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
+                pipeline.save_pretrained(args.output_path)
+
+                generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
                 image = pipeline(generator=generator)
```
```diff
@@ -120,22 +122,31 @@ def main(args):
                 image_pil = PIL.Image.fromarray(image_processed[0])
 
                 # save image
-                pipeline.save_pretrained("./pokemon-ddpm")
-                image_pil.save(f"./pokemon-ddpm/test_{epoch}.png")
+                test_dir = os.path.join(args.output_path, "test_samples")
+                os.makedirs(test_dir, exist_ok=True)
+                image_pil.save(f"{test_dir}/{epoch}.png")
         torch.distributed.barrier()
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Simple example of training script.")
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int)
+    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
+    parser.add_argument("--resolution", type=int, default=64)
+    parser.add_argument("--output_path", type=str, default="ddpm-model")
+    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
+    parser.add_argument("--lr", type=float, default=1e-4)
+    parser.add_argument("--warmup_steps", type=int, default=500)
     parser.add_argument(
         "--mixed_precision",
         type=str,
         default="no",
         choices=["no", "fp16", "bf16"],
         help="Whether to use mixed precision. Choose"
         "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
         "and an Nvidia Ampere GPU.",
     )
 
     args = parser.parse_args()
```
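With the hard-coded constants gone, the schedule length is derived entirely from CLI values. A minimal sketch of that arithmetic with hypothetical numbers (`batches_per_epoch = 313` is illustrative, not from the commit):

```python
from argparse import Namespace

args = Namespace(num_epochs=100, gradient_accumulation_steps=2, warmup_steps=500)
batches_per_epoch = 313  # stand-in for len(train_dataloader)

# Mirrors the get_linear_schedule_with_warmup call in main():
num_training_steps = (batches_per_epoch * args.num_epochs) // args.gradient_accumulation_steps
print(num_training_steps)  # 15650 optimizer steps; the first 500 are linear warmup
```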
src/diffusers/__init__.py

```diff
@@ -8,6 +8,7 @@ from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
+from .models.unet_grad_tts import UNetGradTTSModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
```
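With this export in place, the new Grad-TTS UNet is importable from the package root as well as from its submodule (a minimal sketch, assuming the repo state at this commit):

```python
from diffusers import UNetGradTTSModel                        # new top-level export
from diffusers.models.unet_grad_tts import UNetGradTTSModel  # same class, full path
```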
src/diffusers/models/__init__.py

```diff
@@ -19,3 +19,4 @@
 from .unet import UNetModel
 from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .unet_ldm import UNetLDMModel
+from .unet_grad_tts import UNetGradTTSModel
\ No newline at end of file
```
src/diffusers/models/unet_grad_tts.py (new file, mode 100644)

```python
import math

import torch

try:
    from einops import rearrange, repeat
except:
    print("Einops is not installed")
    pass

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Upsample(torch.nn.Module):
    def __init__(self, dim):
        super(Upsample, self).__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Downsample(torch.nn.Module):
    def __init__(self, dim):
        super(Downsample, self).__init__()
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Rezero(torch.nn.Module):
    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class Block(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = torch.nn.Identity()

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        return self.to_out(out)


class Residual(torch.nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        output = self.fn(x, *args, **kwargs) + x
        return output


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class UNetGradTTSModel(ModelMixin, ConfigMixin):
    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(UNetGradTTSModel, self).__init__()

        self.register(
            dim=dim,
            dim_mults=dim_mults,
            groups=groups,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
            pe_scale=pe_scale,
        )

        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale

        if n_spks > 1:
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_in))),
                        Upsample(dim_in),
                    ]
                )
            )
        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
            x = torch.stack([mu, x], 1)
        else:
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        return (output * mask).squeeze(1)
```

(no newline at end of file)
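As a quick orientation to the new model, here is a shape-level smoke test; the sizes are illustrative and it assumes the repo state at this commit (where `ModelMixin`/`ConfigMixin` still expose `register`). Grad-TTS denoises 80-bin mel spectrograms conditioned on the text-encoder output `mu`, and both the feature and frame axes must be divisible by 2**2 because of the two downsampling stages:

```python
import torch
from diffusers import UNetGradTTSModel

model = UNetGradTTSModel(dim=64, dim_mults=(1, 2, 4), n_feats=80, n_spks=1)

batch, n_feats, frames = 2, 80, 172          # 80 and 172 are both divisible by 4
x = torch.randn(batch, n_feats, frames)      # noisy mel spectrogram
mu = torch.randn(batch, n_feats, frames)     # text-encoder output (same shape)
mask = torch.ones(batch, 1, frames)          # all frames valid
t = torch.rand(batch)                        # diffusion time, one scalar per sample

out = model(x, mask, mu, t)
assert out.shape == (batch, n_feats, frames)
```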
src/diffusers/pipelines/pipeline_grad_tts.py (new file, mode 100644)

```python
""" from https://github.com/jaywalnut310/glow-tts """

import math

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    while True:
        if length % (2**num_downsamplings_in_unet) == 0:
            return length
        length += 1


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path


def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super(LayerNorm, self).__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super(ConvReluNorm, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super(DurationPredictor, self).__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        window_size=None,
        heads_share=True,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=False,
    ):
        super(MultiHeadAttention, self).__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = torch.nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
            self.emb_rel_v = torch.nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = torch.nn.functional.pad(
                relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super(FFN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=None,
        **kwargs,
    ):
        super(Encoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout)
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class TextEncoder(ModelMixin, ConfigMixin):
    def __init__(
        self,
        n_vocab,
        n_feats,
        n_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        window_size=None,
        spk_emb_dim=64,
        n_spks=1,
    ):
        super(TextEncoder, self).__init__()

        self.register(
            n_vocab=n_vocab,
            n_feats=n_feats,
            n_channels=n_channels,
            filter_channels=filter_channels,
            filter_channels_dp=filter_channels_dp,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            window_size=window_size,
            spk_emb_dim=spk_emb_dim,
            n_spks=n_spks,
        )

        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)

        self.encoder = Encoder(
            n_channels + (spk_emb_dim if n_spks > 1 else 0),
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
        )

        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
        self.proj_w = DurationPredictor(
            n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask
```
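To make the alignment utilities at the top of this file concrete: `sequence_mask` turns lengths into a boolean mask, and `generate_path` expands per-token durations into a hard one-hot alignment between text positions and mel frames. A small sketch with illustrative values (importing from this module, assuming the repo state at this commit):

```python
import torch
from diffusers.pipelines.pipeline_grad_tts import generate_path, sequence_mask

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths))         # (2, 5) boolean; row 0 is True in its first 3 columns

duration = torch.tensor([[1, 2, 3]])  # 3 text tokens covering 6 mel frames
mask = torch.ones(1, 3, 6)
path = generate_path(duration, mask)  # (1, 3, 6): token i is "on" for its own frames
print(path[0])
# tensor([[1., 0., 0., 0., 0., 0.],
#         [0., 1., 1., 0., 0., 0.],
#         [0., 0., 0., 1., 1., 1.]])
```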
src/diffusers/pipelines/pipeline_latent_diffusion.py

```diff
@@ -943,7 +943,7 @@ class LatentDiffusion(DiffusionPipeline):
         # 3. optionally sample variance
         variance = 0
         if eta > 0:
-            noise = torch.randn(image.shape, generator=generator)to(image.device)
+            noise = torch.randn(image.shape, generator=generator).to(image.device)
             variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
 
         # 4. set current image to prev_image: x_t -> x_t-1
```
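Beyond restoring the missing dot, the fixed line's pattern is worth noting: the noise is sampled on CPU with the seeded generator and only then moved, so the draw is identical regardless of the target device; sampling directly on GPU would consume a different RNG stream. A minimal sketch:

```python
import torch

generator = torch.manual_seed(0)  # CPU generator
shape = (1, 3, 64, 64)            # illustrative image shape
device = "cuda" if torch.cuda.is_available() else "cpu"
noise = torch.randn(shape, generator=generator).to(device)  # same values on any device
```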
tests/test_modeling_utils.py

```diff
@@ -214,6 +214,21 @@ class PipelineTesterMixin(unittest.TestCase):
         expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
 
+    @slow
+    def test_glide_text2img(self):
+        model_id = "fusing/glide-base"
+        glide = GLIDE.from_pretrained(model_id)
+
+        prompt = "a pencil sketch of a corgi"
+        generator = torch.manual_seed(0)
+        image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
+
+        image_slice = image[0, :3, :3, -1].cpu()
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+
     def test_module_from_pipeline(self):
         model = DiffWave(num_res_layers=4)
         noise_scheduler = DDPMScheduler(timesteps=12)
@@ -229,17 +244,3 @@ class PipelineTesterMixin(unittest.TestCase):
         _ = BDDM.from_pretrained(tmpdirname)
 
         # check if the same works using the DifusionPipeline class
         _ = DiffusionPipeline.from_pretrained(tmpdirname)
-
-    @slow
-    def test_glide_text2img(self):
-        model_id = "fusing/glide-base"
-        glide = GLIDE.from_pretrained(model_id)
-
-        prompt = "a pencil sketch of a corgi"
-        generator = torch.manual_seed(0)
-        image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
-
-        image_slice = image[0, :3, :3, -1].cpu()
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
```