renzhc / diffusers_dcu · Commits · f7d91f8b

Unverified commit f7d91f8b, authored Jun 15, 2022 by Suraj Patil, committed by GitHub on Jun 15, 2022.
Parents: 14a2201f, 304d4d90

Unet for Grad TTS and pipeline
Showing 4 changed files, with 620 additions and 0 deletions (+620, -0):
src/diffusers/__init__.py                       +1    -0
src/diffusers/models/__init__.py                +1    -0
src/diffusers/models/unet_grad_tts.py           +233  -0
src/diffusers/pipelines/pipeline_grad_tts.py    +385  -0
src/diffusers/__init__.py  (+1, -0)

@@ -8,6 +8,7 @@ from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
+from .models.unet_grad_tts import UNetGradTTSModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
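With the line added above, the new class is re-exported at the package root alongside the other UNet variants. A minimal import check (a sketch, assuming the package in this repository is importable as diffusers):

# Sketch: the top-level export added by this diff should resolve.
from diffusers import UNetGradTTSModel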
src/diffusers/models/__init__.py  (+1, -0)

@@ -19,3 +19,4 @@
 from .unet import UNetModel
 from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .unet_ldm import UNetLDMModel
+from .unet_grad_tts import UNetGradTTSModel
\ No newline at end of file
src/diffusers/models/unet_grad_tts.py  (new file, mode 100644, +233, -0)

import math

import torch

try:
    from einops import rearrange, repeat
except:
    print("Einops is not installed")
    pass

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Upsample(torch.nn.Module):
    def __init__(self, dim):
        super(Upsample, self).__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Downsample(torch.nn.Module):
    def __init__(self, dim):
        super(Downsample, self).__init__()
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Rezero(torch.nn.Module):
    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class Block(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock(torch.nn.Module):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = torch.nn.Identity()

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class LinearAttention(torch.nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class Residual(torch.nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        output = self.fn(x, *args, **kwargs) + x
        return output


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class UNetGradTTSModel(ModelMixin, ConfigMixin):
    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(UNetGradTTSModel, self).__init__()

        self.register(
            dim=dim,
            dim_mults=dim_mults,
            groups=groups,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
            pe_scale=pe_scale,
        )

        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale

        if n_spks > 1:
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else torch.nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                        Residual(Rezero(LinearAttention(dim_in))),
                        Upsample(dim_in),
                    ]
                )
            )

        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
            x = torch.stack([mu, x], 1)
        else:
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        return (output * mask).squeeze(1)
\ No newline at end of file
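The new UNet takes a noisy mel spectrogram, the text-conditioned mean mu, a frame mask, and a diffusion time step, and predicts noise at mel resolution. A minimal forward-pass sketch, assuming this commit's constructor defaults and illustrative shapes (an 80-bin mel, a frame count divisible by 4 for the two downsampling stages):

import torch
from diffusers import UNetGradTTSModel  # top-level export added by this commit

# Sketch only: dim, batch size and sequence length below are illustrative, not values from the commit.
model = UNetGradTTSModel(dim=64, dim_mults=(1, 2, 4), n_spks=1, n_feats=80, pe_scale=1000)

batch, n_feats, length = 2, 80, 172       # length chosen divisible by 4 (two downsampling stages)
x = torch.randn(batch, n_feats, length)   # noisy mel spectrogram
mu = torch.randn(batch, n_feats, length)  # text-encoder output (frame-wise mean), same shape as x
mask = torch.ones(batch, 1, length)       # 1 for valid frames, 0 for padding
t = torch.rand(batch)                     # diffusion time in [0, 1]

noise_estimate = model(x, mask, mu, t)    # -> (batch, n_feats, length)
print(noise_estimate.shape)

Note that n_spks is passed as an integer here because the default of None would hit the n_spks > 1 comparison in __init__.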
src/diffusers/pipelines/pipeline_grad_tts.py  (new file, mode 100644, +385, -0)

""" from https://github.com/jaywalnut310/glow-tts """

import math

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    while True:
        if length % (2 ** num_downsamplings_in_unet) == 0:
            return length
        length += 1


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path


def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-4):
        super(LayerNorm, self).__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super(ConvReluNorm, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super(DurationPredictor, self).__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        window_size=None,
        heads_share=True,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=False,
    ):
        super(MultiHeadAttention, self).__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            self.emb_rel_k = torch.nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
            self.emb_rel_v = torch.nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
            )
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = torch.nn.functional.pad(
                relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
        super(FFN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=None,
        **kwargs,
    ):
        super(Encoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class TextEncoder(ModelMixin, ConfigMixin):
    def __init__(
        self,
        n_vocab,
        n_feats,
        n_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        window_size=None,
        spk_emb_dim=64,
        n_spks=1,
    ):
        super(TextEncoder, self).__init__()

        self.register(
            n_vocab=n_vocab,
            n_feats=n_feats,
            n_channels=n_channels,
            filter_channels=filter_channels,
            filter_channels_dp=filter_channels_dp,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            window_size=window_size,
            spk_emb_dim=spk_emb_dim,
            n_spks=n_spks,
        )

        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels ** -0.5)

        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)

        self.encoder = Encoder(
            n_channels + (spk_emb_dim if n_spks > 1 else 0),
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
        )

        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
        self.proj_w = DurationPredictor(
            n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask
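The TextEncoder above maps padded phoneme IDs to frame-wise mel statistics mu, log-durations logw, and the corresponding mask, which is what the UNet's mu conditioning consumes. A rough instantiation sketch; the hyperparameters below mirror typical Grad-TTS settings but are assumptions, not values taken from this diff:

import torch
from diffusers.pipelines.pipeline_grad_tts import TextEncoder

# Illustrative hyperparameters (not from this commit): small vocab, 80-bin mel, 192-channel encoder.
encoder = TextEncoder(
    n_vocab=149, n_feats=80, n_channels=192,
    filter_channels=768, filter_channels_dp=256,
    n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1, window_size=4,
)

x = torch.randint(0, 149, (2, 37))        # padded phoneme-ID sequences, shape (batch, max_len)
x_lengths = torch.tensor([37, 25])        # true lengths before padding
mu, logw, x_mask = encoder(x, x_lengths)  # -> (2, 80, 37), (2, 1, 37), (2, 1, 37)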