renzhc / diffusers_dcu

Commit 850d4345, authored Jun 15, 2022 by anton-l

Merge remote-tracking branch 'origin/main'

Parents: cfe6eb16, f7d91f8b
Showing 8 changed files with 636 additions and 1049 deletions.
- README.md (+4, -3)
- src/diffusers/__init__.py (+1, -0)
- src/diffusers/models/__init__.py (+1, -0)
- src/diffusers/models/unet_grad_tts.py (+233, -0)
- src/diffusers/pipelines/configuration_ldmbert.py (+0, -146)
- src/diffusers/pipelines/modeling_vae.py (+0, -859)
- src/diffusers/pipelines/pipeline_grad_tts.py (+385, -0)
- src/diffusers/pipelines/pipeline_latent_diffusion.py (+12, -41)
README.md (+4, -3) · View file @ 850d4345

````diff
@@ -166,13 +166,14 @@ image_pil.save("test.png")
 #### **Text to Image generation with Latent Diffusion**
 
 _Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
 
 ```python
 from diffusers import DiffusionPipeline
 
 ldm = DiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")
 
-generator = torch.Generator()
-generator = generator.manual_seed(6694729458485568)
+generator = torch.manual_seed(42)
 
 prompt = "A painting of a squirrel eating a burger"
 image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)

@@ -197,7 +198,7 @@ from diffusers import BDDM, DiffusionPipeline
 torch_device = "cuda"
 
 # load the BDDM pipeline
-bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")
+bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech")
 
 # load tacotron2 to get the mel spectograms
 tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
````
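A quick note on the seeding change in the first hunk (commentary, not part of the commit): `torch.manual_seed(42)` seeds PyTorch's default CPU generator and returns it, so it can be passed straight to the pipeline as `generator=`. A minimal sketch of the equivalent explicit form, with the seed value purely illustrative:

```python
import torch

# torch.manual_seed(...) seeds the global default generator and returns it;
# the explicit form below keeps a dedicated generator object instead.
generator = torch.Generator()
generator = generator.manual_seed(42)

# Either generator can then be passed to the pipeline call, e.g.
# image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)
```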
src/diffusers/__init__.py (+1, -0) · View file @ 850d4345

```diff
@@ -8,6 +8,7 @@ from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
+from .models.unet_grad_tts import UNetGradTTSModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
```
src/diffusers/models/__init__.py (+1, -0) · View file @ 850d4345

```diff
@@ -19,3 +19,4 @@
 from .unet import UNetModel
 from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .unet_ldm import UNetLDMModel
+from .unet_grad_tts import UNetGradTTSModel
\ No newline at end of file
```
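With the import added both here and in the package-level `src/diffusers/__init__.py`, the Grad-TTS UNet becomes reachable from the package root. A hypothetical smoke test (assuming this revision of `diffusers` is installed), shown only as a sketch:

```python
# Hypothetical check, not part of the commit: the new class should now be
# importable both from the package root and from the models subpackage.
from diffusers import UNetGradTTSModel
from diffusers.models import UNetGradTTSModel as FromModels

assert UNetGradTTSModel is FromModels
```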
src/diffusers/models/unet_grad_tts.py (new file, 0 → 100644, +233) · View file @ 850d4345

```python
import math

import torch

try:
    from einops import rearrange, repeat
except:
    print("Einops is not installed")
    pass

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Upsample(torch.nn.Module):
    def __init__(self, dim):
        super(Upsample, self).__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Downsample(torch.nn.Module):
    def __init__(self, dim):
        super(Downsample, self).__init__()
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Rezero(torch.nn.Module):
    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class Block(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = torch.nn.Identity()

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class Residual(torch.nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        output = self.fn(x, *args, **kwargs) + x
        return output


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class UNetGradTTSModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
        dim,
        dim_mults=(1, 2, 4),
        groups=8,
        n_spks=None,
        spk_emb_dim=64,
        n_feats=80,
        pe_scale=1000,
    ):
        super(UNetGradTTSModel, self).__init__()
        self.register(
            dim=dim,
            dim_mults=dim_mults,
            groups=groups,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
            pe_scale=pe_scale,
        )

        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale

        if n_spks > 1:
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_in))),
                        Upsample(dim_in),
                    ]
                )
            )
        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
            x = torch.stack([mu, x], 1)
        else:
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        return (output * mask).squeeze(1)
```
\ No newline at end of file
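As an aside (commentary, not part of the commit), here is a minimal sketch of how the new `UNetGradTTSModel` might be exercised, going only by the constructor and `forward` signatures above. The batch size, frame count, and `dim` value are illustrative; the frame count must stay divisible by 4 so the two down/upsampling stages round-trip cleanly, and `n_feats=80` matches the default mel dimension.

```python
import torch

from diffusers import UNetGradTTSModel  # exposed via this commit's __init__ changes

# Illustrative sizes only.
model = UNetGradTTSModel(dim=64, dim_mults=(1, 2, 4), n_feats=80, n_spks=1)

batch, n_feats, frames = 2, 80, 172       # frames divisible by 2**2
x = torch.randn(batch, n_feats, frames)   # noisy mel-spectrogram x_t
mu = torch.randn(batch, n_feats, frames)  # text-conditional mean from the encoder
mask = torch.ones(batch, 1, frames)       # frame-level padding mask
t = torch.rand(batch)                     # diffusion time in [0, 1]

score = model(x, mask, mu, t)
print(score.shape)  # expected: torch.Size([2, 80, 172])
```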
src/diffusers/pipelines/configuration_ldmbert.py (deleted, 100644 → 0, -146) · View file @ cfe6eb16

````python
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LDMBERT model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
}


class LDMBertConfig(PretrainedConfig):
    r"""
This is the configuration class to store the configuration of a [`LDMBertModel`]. It is used to instantiate a
LDMBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the LDMBERT
[facebook/ldmbert-large](https://huggingface.co/facebook/ldmbert-large) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50265):
Vocabulary size of the LDMBERT model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LDMBertModel`] or [`TFLDMBertModel`].
d_model (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
encoder_layers (`int`, *optional*, defaults to 12):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 12):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
scale_embedding (`bool`, *optional*, defaults to `False`):
Scale embeddings by diving by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
num_labels: (`int`, *optional*, defaults to 3):
The number of labels to use in [`LDMBertForSequenceClassification`].
forced_eos_token_id (`int`, *optional*, defaults to 2):
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
`eos_token_id`.
Example:
```python
>>> from transformers import LDMBertModel, LDMBertConfig
>>> # Initializing a LDMBERT facebook/ldmbert-large style configuration
>>> configuration = LDMBertConfig()
>>> # Initializing a model from the facebook/ldmbert-large style configuration
>>> model = LDMBertModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

    model_type = "ldmbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=77,
        encoder_layers=32,
        encoder_ffn_dim=5120,
        encoder_attention_heads=8,
        head_dim=64,
        encoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1280,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        pad_token_id=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        super().__init__(pad_token_id=pad_token_id, **kwargs)
````
src/diffusers/pipelines/modeling_vae.py (deleted, 100644 → 0, -859) · View file @ cfe6eb16

This diff is collapsed (859 deleted lines not shown).
src/diffusers/pipelines/pipeline_grad_tts.py (new file, 0 → 100644, +385) · View file @ 850d4345

```python
""" from https://github.com/jaywalnut310/glow-tts """

import math

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    while True:
        if length % (2**num_downsamplings_in_unet) == 0:
            return length
        length += 1


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path


def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super(LayerNorm, self).__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super(ConvReluNorm, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super(DurationPredictor, self).__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        window_size=None,
        heads_share=True,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=False,
    ):
        super(MultiHeadAttention, self).__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = torch.nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
            self.emb_rel_v = torch.nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = torch.nn.functional.pad(
                relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super(FFN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=None,
        **kwargs,
    ):
        super(Encoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class TextEncoder(ModelMixin, ConfigMixin):
    def __init__(
        self,
        n_vocab,
        n_feats,
        n_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        window_size=None,
        spk_emb_dim=64,
        n_spks=1,
    ):
        super(TextEncoder, self).__init__()
        self.register(
            n_vocab=n_vocab,
            n_feats=n_feats,
            n_channels=n_channels,
            filter_channels=filter_channels,
            filter_channels_dp=filter_channels_dp,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            window_size=window_size,
            spk_emb_dim=spk_emb_dim,
            n_spks=n_spks,
        )

        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)

        self.encoder = Encoder(
            n_channels + (spk_emb_dim if n_spks > 1 else 0),
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
        )

        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
        self.proj_w = DurationPredictor(
            n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask
```
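Commentary, not part of the commit: a small sketch of what the alignment helpers at the top of this file compute, with made-up toy values and assuming `sequence_mask` and `generate_path` are in scope (e.g. imported from this module):

```python
import torch

# sequence_mask: per-item lengths -> boolean padding mask (True where position < length)
lengths = torch.tensor([3, 5])
print(sequence_mask(lengths).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])

# generate_path: integer per-token durations -> hard monotonic alignment
# between text positions (rows) and output frames (columns)
durations = torch.tensor([[1, 2, 2]])  # one item, three text tokens
attn_mask = torch.ones(1, 3, 5)        # (batch, text_len, frame_len)
print(generate_path(durations, attn_mask))
# tensor([[[1., 0., 0., 0., 0.],
#          [0., 1., 1., 0., 0.],
#          [0., 0., 0., 1., 1.]]])
```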
src/diffusers/pipelines/pipeline_latent_diffusion.py (+12, -41) · View file @ 850d4345

```diff
@@ -903,8 +903,8 @@ class LatentDiffusion(DiffusionPipeline):
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        )
-        image = image.to(torch_device)
+        ).to(torch_device)
 
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding

@@ -937,46 +937,17 @@ class LatentDiffusion(DiffusionPipeline):
             pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
             pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
 
-            # 2. get actual t and t-1
-            train_step = inference_step_times[t]
-            prev_train_step = inference_step_times[t - 1] if t > 0 else -1
-
-            # 3. compute alphas, betas
-            alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
-            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
-            beta_prod_t = 1 - alpha_prod_t
-            beta_prod_t_prev = 1 - alpha_prod_t_prev
-
-            # 4. Compute predicted previous image from predicted noise
-            # First: compute predicted original image from predicted noise also called
-            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
-
-            # Second: Compute variance: "sigma_t(η)" -> see formula (16)
-            # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-            std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
-            std_dev_t = eta * std_dev_t
-
-            # Third: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
-
-            # Forth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-            pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
-
-            # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
-            # Note: eta = 1.0 essentially corresponds to DDPM
-            if eta > 0.0:
-                noise = torch.randn(
-                    (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
-                    generator=generator,
-                )
-                noise = noise.to(torch_device)
-                prev_image = pred_prev_image + std_dev_t * noise
-            else:
-                prev_image = pred_prev_image
+            # 2. predict previous mean of image x_t-1
+            pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
 
-            # 6. Set current image to prev_image: x_t -> x_t-1
-            image = prev_image
+            # 3. optionally sample variance
+            variance = 0
+            if eta > 0:
+                noise = torch.randn(image.shape, generator=generator).to(image.device)
+                variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
+
+            # 4. set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance
 
         # scale and decode image with vae
         image = 1 / 0.18215 * image
```
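For reference, the manual update removed above is the DDIM sampling step, equations (12) and (16) of https://arxiv.org/pdf/2010.02502.pdf, which the new `self.noise_scheduler.step(...)` call is meant to encapsulate. In the notation of the deleted comments, with the cumulative alpha product written as alpha-bar and the predicted noise as epsilon-theta:

```latex
\begin{align*}
\hat{x}_0 &= \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}
  && \text{``predicted } x_0\text{'', eq.\ (12)} \\
\sigma_t(\eta) &= \eta \sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}}
  \sqrt{1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}
  && \text{eq.\ (16)} \\
x_{t-1} &= \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0
  + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t, t)
  + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)
  && \text{eq.\ (12)}
\end{align*}
```

With eta = 0 the update is deterministic; eta = 1 essentially recovers DDPM, as the deleted comment notes.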