ModelZoo / index-tts-vllm / Commits

Commit ab9c00af, authored Jan 07, 2026 by yangzhong

    init submission

Pipeline #3176 failed with stages in 0 seconds
Changes: 316 · Pipelines: 1

Showing 20 changed files with 4048 additions and 0 deletions (+4048, -0)
indextts/utils/maskgct/models/codec/codec_sampler.py                        +126  -0
indextts/utils/maskgct/models/codec/codec_trainer.py                        +166  -0
indextts/utils/maskgct/models/codec/facodec/__init__.py                       +0  -0
indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py      +5  -0
indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py          +29  -0
indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py       +96  -0
indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py     +57  -0
indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py               +98  -0
indextts/utils/maskgct/models/codec/facodec/facodec_inference.py            +137  -0
indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py              +776  -0
indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py           +1  -0
indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7                +0  -0
indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py            +219  -0
indextts/utils/maskgct/models/codec/facodec/modules/attentions.py           +437  -0
indextts/utils/maskgct/models/codec/facodec/modules/commons.py              +331  -0
indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py     +35  -0
indextts/utils/maskgct/models/codec/facodec/modules/layers.py               +460  -0
indextts/utils/maskgct/models/codec/facodec/modules/quantize.py             +741  -0
indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py        +110  -0
indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py              +224  -0
indextts/utils/maskgct/models/codec/codec_sampler.py  0 → 100644

# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import random

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import (
    BatchSampler,
    RandomSampler,
    Sampler,
    SequentialSampler,
)


class ScheduledSampler(Sampler):
    """A sampler that samples data from a given concat-dataset.

    Args:
        concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
        batch_size (int): batch size
        holistic_shuffle (bool): whether to shuffle the whole dataset or not
        logger (logging.Logger): logger to print warning message

    Usage:
        For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
        >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
        [3, 4, 5, 0, 1, 2, 6, 7, 8]
    """

    def __init__(
        self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
    ):
        # NOTE: the "type" parameter shadows the builtin type(), so the
        # type(...) calls in the error messages below would fail if triggered.
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    type(concat_dataset)
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(type(batch_size))
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    type(holistic_shuffle)
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle

        affected_dataset_name = []
        affected_dataset_len = []
        for dataset in concat_dataset.datasets:
            dataset_len = len(dataset)
            dataset_name = dataset.get_dataset_name()
            if dataset_len < batch_size:
                affected_dataset_name.append(dataset_name)
                affected_dataset_len.append(dataset_len)

        self.type = type
        for dataset_name, dataset_len in zip(
            affected_dataset_name, affected_dataset_len
        ):
            if not type == "valid":
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                        type, dataset_name, dataset_len, batch_size
                    )
                )

    def __len__(self):
        # the number of batches with drop last
        num_of_batches = sum(
            [
                math.floor(len(dataset) / self.batch_size)
                for dataset in self.concat_dataset.datasets
            ]
        )
        return num_of_batches * self.batch_size

    def __iter__(self):
        iters = []
        for dataset in self.concat_dataset.datasets:
            iters.append(
                SequentialSampler(dataset).__iter__()
                if self.holistic_shuffle
                else RandomSampler(dataset).__iter__()
            )
        init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
        output_batches = []
        for dataset_idx in range(len(self.concat_dataset.datasets)):
            cur_batch = []
            for idx in iters[dataset_idx]:
                cur_batch.append(idx + init_indices[dataset_idx])
                if len(cur_batch) == self.batch_size:
                    output_batches.append(cur_batch)
                    cur_batch = []
            if self.type == "valid" and len(cur_batch) > 0:
                output_batches.append(cur_batch)
                cur_batch = []
        # force drop last in training
        random.shuffle(output_batches)
        output_indices = [item for sublist in output_batches for item in sublist]
        return iter(output_indices)


def build_samplers(concat_dataset: Dataset, cfg, logger, type):
    sampler = ScheduledSampler(
        concat_dataset,
        cfg.train.batch_size,
        cfg.train.sampler.holistic_shuffle,
        logger,
        type,
    )
    batch_sampler = BatchSampler(
        sampler,
        cfg.train.batch_size,
        cfg.train.sampler.drop_last if not type == "valid" else False,
    )
    return sampler, batch_sampler
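A minimal usage sketch (not part of the commit): ScheduledSampler requires each member dataset to expose get_dataset_name(), so the ToyDataset class below is hypothetical.

import logging
from torch.utils.data import ConcatDataset, Dataset

class ToyDataset(Dataset):
    def __init__(self, name, items):
        self.name, self.items = name, items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]
    def get_dataset_name(self):  # required by ScheduledSampler
        return self.name

concat = ConcatDataset([ToyDataset("a", [0, 1, 2]), ToyDataset("b", [3, 4, 5])])
sampler = ScheduledSampler(concat, batch_size=3, holistic_shuffle=False,
                           logger=logging.getLogger(__name__))
# Indices stay grouped in per-dataset batches; the batch order is shuffled,
# e.g. [4, 3, 5, 0, 2, 1].
print(list(sampler))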
indextts/utils/maskgct/models/codec/codec_trainer.py  0 → 100644

# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
from pathlib import Path
import re

import accelerate
import json5
import numpy as np
import torch
from accelerate.utils import ProjectConfiguration
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.codec.codec_sampler import build_samplers


class CodecTrainer:
    def __init__(self):
        super().__init__()

    def _init_accelerator(self):
        """Initialize the accelerator components."""
        self.exp_dir = os.path.join(
            os.path.abspath(self.cfg.log_dir), self.args.exp_name
        )
        project_config = ProjectConfiguration(
            project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log")
        )
        self.accelerator = accelerate.Accelerator(
            gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
            log_with=self.cfg.train.tracker,
            project_config=project_config,
        )
        if self.accelerator.is_main_process:
            os.makedirs(project_config.project_dir, exist_ok=True)
            os.makedirs(project_config.logging_dir, exist_ok=True)
        with self.accelerator.main_process_first():
            self.accelerator.init_trackers(self.args.exp_name)

    def _build_dataset(self):
        pass

    def _build_criterion(self):
        pass

    def _build_model(self):
        pass

    def _build_dataloader(self):
        """Build dataloader which merges a series of datasets."""
        # Build dataset instance for each dataset and combine them by ConcatDataset
        Dataset, Collator = self._build_dataset()

        # Build train set
        train_dataset = Dataset(self.cfg, self.cfg.dataset, is_valid=False)
        train_collate = Collator(self.cfg)
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=self.accelerator.num_processes,
            rank=self.accelerator.local_process_index,
            shuffle=True,
            seed=self.cfg.train.random_seed,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.cfg.train.batch_size,
            collate_fn=train_collate,
            sampler=sampler,
            num_workers=self.cfg.train.dataloader.num_worker,
            pin_memory=self.cfg.train.dataloader.pin_memory,
        )
        return train_loader, None

    def _build_optimizer(self):
        pass

    def _build_scheduler(self):
        pass

    def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
        """Load model from checkpoint. If a folder is given, it will
        load the latest checkpoint in checkpoint_dir. If a path is given,
        it will load the checkpoint specified by checkpoint_path.
        **Only use this method after** ``accelerator.prepare()``.
        """
        if checkpoint_path is None:
            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]
        if resume_type == "resume":
            self.accelerator.load_state(checkpoint_path)
        elif resume_type == "finetune":
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            self.logger.info("Load model weights for finetune SUCCESS!")
        else:
            raise ValueError("Unsupported resume type: {}".format(resume_type))
        self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
        self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
        return checkpoint_path

    def train_loop(self):
        pass

    def _train_epoch(self):
        pass

    def _valid_epoch(self):
        pass

    def _train_step(self):
        pass

    def _valid_step(self):
        pass

    def _inference(self):
        pass

    def _set_random_seed(self, seed):
        """Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    def _check_nan(self, loss):
        if torch.any(torch.isnan(loss)):
            self.logger.fatal("Fatal Error: NaN!")
            self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)

    def _check_basic_configs(self):
        if self.cfg.train.gradient_accumulation_step <= 0:
            self.logger.fatal("Invalid gradient_accumulation_step value!")
            self.logger.error(
                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
            )
            self.accelerator.end_training()
            raise ValueError(
                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
            )

    def _count_parameters(self):
        pass

    def _dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        json5.dump(
            self.cfg,
            open(path, "w"),
            indent=4,
            sort_keys=True,
            ensure_ascii=False,
            quote_keys=True,
        )

    def _is_valid_pattern(self, directory_name):
        directory_name = str(directory_name)
        pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
        return re.match(pattern, directory_name) is not None
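The checkpoint bookkeeping above relies on a fixed directory-name convention; a short sketch (illustrative values) of how the sort key in _load_model and the regex in _is_valid_pattern read such a name:

# "epoch-XXXX_step-XXXXXXX_loss-X.XXXXXX" is the convention checked by
# _is_valid_pattern; _load_model recovers epoch and step from it.
name = "epoch-0012_step-0034567_loss-1.234567"
epoch = int(name.split("_")[-3].split("-")[-1])  # 12
step = int(name.split("_")[-2].split("-")[-1])   # 34567
print(epoch, step)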
indextts/utils/maskgct/models/codec/facodec/__init__.py  0 → 100644

(empty file)
indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py  0 → 100644

# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

from .filter import *
from .resample import *
from .act import *
indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py  0 → 100644

# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn

from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B, C, T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x
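A standalone sketch of what the wrapper does: the activation is applied at twice the sample rate, so harmonics it creates above the original Nyquist are filtered out, and the output length matches the input.

import torch
import torch.nn as nn

act = Activation1d(activation=nn.LeakyReLU(0.2))
x = torch.randn(1, 8, 256)  # [B, C, T]
y = act(x)
print(y.shape)  # torch.Size([1, 8, 256]): length is preserved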
indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py  0 → 100644

# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1.0, device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    # return filter [1,1,kernel_size]
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
        filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
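A quick standalone check (not part of the commit) that the kernel behaves as the normalization comment promises: unit DC gain.

import torch

k = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(k.shape)         # torch.Size([1, 1, 12])
print(float(k.sum()))  # ~1.0, so a constant signal passes through unchanged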
indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py  0 → 100644

# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from torch.nn import functional as F

from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = (
            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        )
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left : -self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        xx = self.lowpass(x)

        return xx
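A round-trip sketch (standalone): for a band-limited input, upsampling by 2 and then downsampling by 2 should approximately return the original signal.

import math
import torch

up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
t = torch.linspace(0, 1, 512)
x = torch.sin(2 * math.pi * 8 * t).view(1, 1, -1)  # low-frequency tone, [B, C, T]
y = down(up(x))
print(y.shape, float((x - y).abs().max()))  # same shape, small reconstruction error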
indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py  0 → 100644

# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import random

import numpy as np
import torchaudio
import librosa
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence

from utils.data_utils import *
from models.codec.codec_dataset import CodecDataset


class FAcodecDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to use train or valid dataset
        """
        self.data_root_dir = cfg.dataset
        self.data_list = []
        # walk through the dataset directory recursively, save all files that
        # end with .wav/.mp3/.opus/.flac/.m4a
        for root, _, files in os.walk(self.data_root_dir):
            for file in files:
                if file.endswith((".wav", ".mp3", ".opus", ".flac", ".m4a")):
                    self.data_list.append(os.path.join(root, file))
        self.sr = cfg.preprocess_params.sr
        self.duration_range = cfg.preprocess_params.duration_range
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=cfg.preprocess_params.spect_params.n_mels,
            n_fft=cfg.preprocess_params.spect_params.n_fft,
            win_length=cfg.preprocess_params.spect_params.win_length,
            hop_length=cfg.preprocess_params.spect_params.hop_length,
        )
        self.mean, self.std = -4, 4

    def preprocess(self, wave):
        wave_tensor = (
            torch.from_numpy(wave).float() if isinstance(wave, np.ndarray) else wave
        )
        mel_tensor = self.to_mel(wave_tensor)
        mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self.mean) / self.std
        return mel_tensor

    def __len__(self):
        # return len(self.data_list)
        return len(self.data_list)
        # return a fixed number for testing

    def __getitem__(self, index):
        wave, _ = librosa.load(self.data_list[index], sr=self.sr)
        # NOTE: the loaded audio is discarded here; the waveform is replaced by
        # random noise of a random duration (looks like a debugging stub left in).
        wave = np.random.randn(self.sr * random.randint(*self.duration_range))
        wave = wave / np.max(np.abs(wave))
        mel = self.preprocess(wave).squeeze(0)
        wave = torch.from_numpy(wave).float()
        return wave, mel


class FAcodecCollator(object):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        # batch[0] = wave, mel, text, f0, speakerid
        batch_size = len(batch)

        # sort by mel length
        lengths = [b[1].shape[1] for b in batch]
        batch_indexes = np.argsort(lengths)[::-1]
        batch = [batch[bid] for bid in batch_indexes]

        nmels = batch[0][1].size(0)
        max_mel_length = max([b[1].shape[1] for b in batch])
        max_wave_length = max([b[0].size(0) for b in batch])

        mels = torch.zeros((batch_size, nmels, max_mel_length)).float() - 10
        waves = torch.zeros((batch_size, max_wave_length)).float()

        mel_lengths = torch.zeros(batch_size).long()
        wave_lengths = torch.zeros(batch_size).long()

        for bid, (wave, mel) in enumerate(batch):
            mel_size = mel.size(1)
            mels[bid, :, :mel_size] = mel
            waves[bid, : wave.size(0)] = wave
            mel_lengths[bid] = mel_size
            wave_lengths[bid] = wave.size(0)

        return waves, mels, wave_lengths, mel_lengths
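A toy-tensor sketch of the collator's contract: it sorts by mel length, then zero-pads waves and pads mels with -10 (the log-mel floor) to the longest item in the batch.

import torch

collate = FAcodecCollator(cfg=None)  # cfg is stored but unused in __call__
batch = [
    (torch.randn(24000), torch.randn(80, 80)),   # (wave, mel) pairs of
    (torch.randn(48000), torch.randn(80, 160)),  # different lengths
]
waves, mels, wave_lens, mel_lens = collate(batch)
print(waves.shape, mels.shape)  # torch.Size([2, 48000]) torch.Size([2, 80, 160])
print(wave_lens.tolist(), mel_lens.tolist())  # [48000, 24000] [160, 80]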
indextts/utils/maskgct/models/codec/facodec/facodec_inference.py  0 → 100644

# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import shutil
import warnings
import argparse
import torch
import os
import yaml

warnings.simplefilter("ignore")

from .modules.commons import *
import time

import torchaudio
import librosa
from collections import OrderedDict


class FAcodecInference(object):
    def __init__(self, args=None, cfg=None):
        self.args = args
        self.cfg = cfg
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._build_model()
        self._load_checkpoint()

    def _build_model(self):
        model = build_model(self.cfg.model_params)
        _ = [model[key].to(self.device) for key in model]
        return model

    def _load_checkpoint(self):
        sd = torch.load(self.args.checkpoint_path, map_location="cpu")
        sd = sd["net"] if "net" in sd else sd
        new_params = dict()
        for key, state_dict in sd.items():
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if k.startswith("module."):
                    k = k[7:]
                new_state_dict[k] = v
            new_params[key] = new_state_dict
        for key in new_params:
            if key in self.model:
                self.model[key].load_state_dict(new_params[key])
        _ = [self.model[key].eval() for key in self.model]

    @torch.no_grad()
    def inference(self, source, output_dir):
        source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
        source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)

        z = self.model.encoder(source_audio[None, ...].to(self.device).float())
        (
            z,
            quantized,
            commitment_loss,
            codebook_loss,
            timbre,
            codes,
        ) = self.model.quantizer(
            z,
            source_audio[None, ...].to(self.device).float(),
            n_c=self.cfg.model_params.n_c_codebooks,
            return_codes=True,
        )
        full_pred_wave = self.model.decoder(z)

        os.makedirs(output_dir, exist_ok=True)
        source_name = source.split("/")[-1].split(".")[0]
        torchaudio.save(
            f"{output_dir}/reconstructed_{source_name}.wav",
            full_pred_wave[0].cpu(),
            self.cfg.preprocess_params.sr,
        )

        print(
            "Reconstructed audio saved as: ",
            f"{output_dir}/reconstructed_{source_name}.wav",
        )

        return quantized, codes

    @torch.no_grad()
    def voice_conversion(self, source, reference, output_dir):
        source_audio = librosa.load(source, sr=self.cfg.preprocess_params.sr)[0]
        source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)

        reference_audio = librosa.load(reference, sr=self.cfg.preprocess_params.sr)[0]
        reference_audio = (
            torch.tensor(reference_audio).unsqueeze(0).float().to(self.device)
        )

        z = self.model.encoder(source_audio[None, ...].to(self.device).float())
        z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
            z,
            source_audio[None, ...].to(self.device).float(),
            n_c=self.cfg.model_params.n_c_codebooks,
        )

        z_ref = self.model.encoder(reference_audio[None, ...].to(self.device).float())
        (
            z_ref,
            quantized_ref,
            commitment_loss_ref,
            codebook_loss_ref,
            timbre_ref,
        ) = self.model.quantizer(
            z_ref,
            reference_audio[None, ...].to(self.device).float(),
            n_c=self.cfg.model_params.n_c_codebooks,
        )

        z_conv = self.model.quantizer.voice_conversion(
            quantized[0] + quantized[1],
            reference_audio[None, ...].to(self.device).float(),
        )
        full_pred_wave = self.model.decoder(z_conv)

        os.makedirs(output_dir, exist_ok=True)
        source_name = source.split("/")[-1].split(".")[0]
        reference_name = reference.split("/")[-1].split(".")[0]
        torchaudio.save(
            f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
            full_pred_wave[0].cpu(),
            self.cfg.preprocess_params.sr,
        )

        print(
            "Voice conversion results saved as: ",
            f"{output_dir}/converted_{source_name}_to_{reference_name}.wav",
        )
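A usage sketch with hypothetical paths: FAcodecInference only reads args.checkpoint_path plus an attribute-style config exposing cfg.preprocess_params.sr and cfg.model_params.n_c_codebooks, which must match the checkpoint's training config.

from argparse import Namespace

args = Namespace(checkpoint_path="ckpts/facodec.pth")  # hypothetical path
inference = FAcodecInference(args=args, cfg=cfg)       # cfg built elsewhere (e.g. from YAML)
quantized, codes = inference.inference("input.wav", "outputs")
inference.voice_conversion("input.wav", "reference.wav", "outputs")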
indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py  0 → 100644

# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import random
from pathlib import Path
import re
import glob

import accelerate
import json
import numpy as np
import torch
from accelerate.utils import ProjectConfiguration
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchaudio
from accelerate.logging import get_logger

from models.codec.facodec.facodec_dataset import FAcodecDataset, FAcodecCollator
from models.codec.codec_sampler import build_samplers
from models.codec.codec_trainer import CodecTrainer

from modules.dac.nn.loss import (
    MultiScaleSTFTLoss,
    MelSpectrogramLoss,
    GANLoss,
    L1Loss,
    FocalLoss,
)
from audiotools import AudioSignal

from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

try:
    import nemo.collections.asr as nemo_asr
except ImportError:
    print(
        "Unable to import nemo_asr, titanet outputs will be set to random values, you may only run debugging mode. DO NOT USE THIS FOR TRAINING"
    )
    nemo_asr = None

from models.codec.facodec.modules.commons import (
    build_model,
    load_checkpoint,
    load_F0_models,
    log_norm,
)
from models.codec.facodec.optimizer import build_optimizer


class FAcodecTrainer(CodecTrainer):
    def __init__(self, args, cfg):
        super().__init__()

        self.args = args
        self.cfg = cfg

        cfg.exp_name = args.exp_name

        # Init accelerator
        self._init_accelerator()
        self.accelerator.wait_for_everyone()

        # Init logger
        with self.accelerator.main_process_first():
            self.logger = get_logger(args.exp_name, log_level=args.log_level)

        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New training process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")
        self.logger.debug(f"Using {args.log_level.upper()} logging level.")
        self.logger.info(f"Experiment name: {args.exp_name}")
        self.logger.info(f"Experiment directory: {self.exp_dir}")
        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
        if self.accelerator.is_main_process:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # Init training status
        self.batch_count: int = 0
        self.step: int = 0
        self.epoch: int = 0

        self.max_epoch = (
            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
        )
        self.logger.info(
            "Max epoch: {}".format(
                self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
            )
        )

        # Check potential errors
        if self.accelerator.is_main_process:
            self._check_basic_configs()
            self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
            self.checkpoints_path = [
                [] for _ in range(len(self.save_checkpoint_stride))
            ]
            self.run_eval = self.cfg.train.run_eval

        # Set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(f"Setting random seed done in {(end - start) / 1e6:.2f}ms")
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # Build dataloader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.train_dataloader, self.valid_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # Build model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            for _, model in self.model.items():
                self.logger.debug(model)
            self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
            self.logger.info(f"Model parameters: {self._count_parameters() / 1e6:.2f}M")

        # Build optimizers and schedulers
        with self.accelerator.main_process_first():
            self.logger.info("Building optimizer and scheduler...")
            start = time.monotonic_ns()
            self.optimizer = self._build_optimizer()
            end = time.monotonic_ns()
            self.logger.info(
                f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
            )

        # Build helper models
        with self.accelerator.main_process_first():
            self.logger.info("Building helper models...")
            start = time.monotonic_ns()
            self._built_helper_model()
            end = time.monotonic_ns()
            self.logger.info(
                f"Building helper models done in {(end - start) / 1e6:.2f}ms"
            )

        # Accelerator preparing
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        for k in self.model:
            self.model[k] = self.accelerator.prepare(self.model[k])
        for k, v in self.optimizer.optimizers.items():
            self.optimizer.optimizers[k] = self.accelerator.prepare(
                self.optimizer.optimizers[k]
            )
            self.optimizer.schedulers[k] = self.accelerator.prepare(
                self.optimizer.schedulers[k]
            )
        end = time.monotonic_ns()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")

        # Build criterions
        with self.accelerator.main_process_first():
            self.logger.info("Building criterion...")
            start = time.monotonic_ns()
            self.criterions = self._build_criterion()
            end = time.monotonic_ns()
            self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")

        # Resume checkpoints
        with self.accelerator.main_process_first():
            self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
            if args.resume_type:
                self.logger.info("Resuming from checkpoint...")
                start = time.monotonic_ns()
                ckpt_path = Path(args.checkpoint)
                if self._is_valid_pattern(ckpt_path.parts[-1]):
                    ckpt_path = self._load_model(args.checkpoint, args.resume_type)
                else:
                    ckpt_path = self._load_model(
                        args.checkpoint, resume_type=args.resume_type
                    )
                end = time.monotonic_ns()
                self.logger.info(
                    f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
                )
                self.checkpoints_path = json.load(
                    open(os.path.join(ckpt_path, "ckpts.json"), "r")
                )

            if self.accelerator.is_main_process:
                os.makedirs(self.checkpoint_dir, exist_ok=True)
            self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # Save config
        self.config_save_path = os.path.join(self.exp_dir, "args.json")

    def _build_dataset(self):
        return FAcodecDataset, FAcodecCollator

    def _build_criterion(self):
        criterions = dict()
        stft_criterion = MultiScaleSTFTLoss()
        mel_criterion = MelSpectrogramLoss(
            n_mels=[5, 10, 20, 40, 80, 160, 320],
            window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
            mel_fmin=[0, 0, 0, 0, 0, 0, 0],
            mel_fmax=[None, None, None, None, None, None, None],
            pow=1.0,
            mag_weight=0.0,
            clamp_eps=1e-5,
        )
        content_criterion = FocalLoss(gamma=2)
        l1_criterion = L1Loss()
        criterions["stft"] = stft_criterion
        criterions["mel"] = mel_criterion
        criterions["l1"] = l1_criterion
        criterions["content"] = content_criterion

        return criterions

    def _build_model(self):
        model = build_model(self.cfg.model_params)
        _ = [model[key].to(self.accelerator.device) for key in model]
        return model

    def _built_helper_model(self):
        device = self.accelerator.device
        self.pitch_extractor = load_F0_models(self.cfg.F0_path).to(device)

        # load model and processor
        self.w2v_processor = Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
        )
        self.w2v_model = Wav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
        ).to(device)
        self.w2v_model.eval()

        if nemo_asr is None:
            self.speaker_model = None
        else:
            self.speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
                "nvidia/speakerverification_en_titanet_large"
            )
            self.speaker_model = self.speaker_model.to(device)
            self.speaker_model.eval()

    def _build_optimizer(self):
        scheduler_params = {
            "warmup_steps": self.cfg.loss_params.warmup_steps,
            "base_lr": self.cfg.loss_params.base_lr,
        }
        optimizer = build_optimizer(
            {key: self.model[key] for key in self.model},
            scheduler_params_dict={key: scheduler_params.copy() for key in self.model},
            lr=float(scheduler_params["base_lr"]),
        )

        return optimizer

    def train_loop(self):
        """Training process"""
        self.accelerator.wait_for_everyone()

        # Dump config
        if self.accelerator.is_main_process:
            self._dump_cfg(self.config_save_path)
        _ = [self.model[key].train() for key in self.model]
        self.optimizer.zero_grad()

        # Sync and start training
        self.accelerator.wait_for_everyone()
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            # Train and Validate
            train_total_loss, train_losses = self._train_epoch()
            for key, loss in train_losses.items():
                self.logger.info("  |- Train/{} Loss: {:.6f}".format(key, loss))
                self.accelerator.log(
                    {"Epoch/Train {} Loss".format(key): loss},
                    step=self.epoch,
                )
            self.accelerator.log(
                {
                    "Epoch/Train Total Loss": train_total_loss,
                },
                step=self.epoch,
            )

            # Update scheduler
            self.accelerator.wait_for_everyone()

            # Check save checkpoint interval
            run_eval = False
            if self.accelerator.is_main_process:
                save_checkpoint = False
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        run_eval |= self.run_eval[i]

            # Save checkpoints
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                print("Saving..")
                state = {
                    "net": {key: self.model[key].state_dict() for key in self.model},
                    "optimizer": self.optimizer.state_dict(),
                    "scheduler": self.optimizer.scheduler_state_dict(),
                    "iters": self.step,
                    "epoch": self.epoch,
                }
                save_path = os.path.join(
                    self.checkpoint_dir,
                    "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.step),
                )
                torch.save(state, save_path)
                json.dump(
                    self.checkpoints_path,
                    open(os.path.join(self.checkpoint_dir, "ckpts.json"), "w"),
                    ensure_ascii=False,
                    indent=4,
                )

            self.accelerator.wait_for_everyone()

            self.epoch += 1

        # Finish training
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            path = os.path.join(
                self.checkpoint_dir,
                "epoch-{:04d}_step-{:07d}".format(
                    self.epoch,
                    self.step,
                ),
            )
            print("Saving..")
            state = {
                "net": {key: self.model[key].state_dict() for key in self.model},
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.optimizer.scheduler_state_dict(),
                "iters": self.step,
                "epoch": self.epoch,
            }
            save_path = os.path.join(
                self.checkpoint_dir,
                "FAcodec_epoch_%05d_step_%05d.pth" % (self.epoch, self.step),
            )
            torch.save(state, save_path)

    def _train_epoch(self):
        """Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        _ = [self.model[key].train() for key in self.model]

        epoch_losses: dict = {}
        epoch_total_loss: int = 0

        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Get losses
            total_loss, losses = self._train_step(batch)
            self.batch_count += 1

            # Log info
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                self.accelerator.log(
                    {
                        "Step/Learning Rate": (
                            self.optimizer.schedulers["encoder"].get_last_lr()[0]
                            if self.step != 0
                            else 0
                        )
                    },
                    step=self.step,
                )
                for key, _ in losses.items():
                    self.accelerator.log(
                        {
                            "Step/Train {} Loss".format(key): losses[key],
                        },
                        step=self.step,
                    )

                if not epoch_losses:
                    epoch_losses = losses
                else:
                    for key, value in losses.items():
                        epoch_losses[key] += value
                epoch_total_loss += total_loss
                self.step += 1

        # Get and log total losses
        self.accelerator.wait_for_everyone()
        epoch_total_loss = (
            epoch_total_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )
        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )
        return epoch_total_loss, epoch_losses

    def _train_step(self, data):
        """Training forward step. Should return average loss of a sample over
        one batch. Provoke ``_forward_step`` is recommended except for special case.
        See ``_train_epoch`` for usage.
        """
        # Init losses
        train_losses = {}
        total_loss = 0

        # Use input feature to get predictions
        data = [b.to(self.accelerator.device, non_blocking=True) for b in data]
        waves, mels, wave_lengths, mel_input_length = data

        # extract semantic latent with w2v model
        waves_16k = torchaudio.functional.resample(waves, 24000, 16000)
        w2v_input = self.w2v_processor(
            waves_16k, sampling_rate=16000, return_tensors="pt"
        ).input_values.to(self.accelerator.device)
        with torch.no_grad():
            w2v_outputs = self.w2v_model(w2v_input.squeeze(0)).logits
            predicted_ids = torch.argmax(w2v_outputs, dim=-1)
            phone_ids = (
                F.interpolate(
                    predicted_ids.unsqueeze(0).float(), mels.size(-1), mode="nearest"
                )
                .long()
                .squeeze(0)
            )

        # get clips
        mel_seg_len = min(
            [int(mel_input_length.min().item()), self.cfg.train.max_frame_len]
        )

        gt_mel_seg = []
        wav_seg = []
        w2v_seg = []

        for bib in range(len(mel_input_length)):
            mel_length = int(mel_input_length[bib].item())

            random_start = (
                np.random.randint(0, mel_length - mel_seg_len)
                if mel_length != mel_seg_len
                else 0
            )
            gt_mel_seg.append(mels[bib, :, random_start : random_start + mel_seg_len])

            # w2v_seg.append(w2v_latent[bib, :, random_start:random_start + mel_seg_len])
            w2v_seg.append(phone_ids[bib, random_start : random_start + mel_seg_len])

            y = waves[bib][random_start * 300 : (random_start + mel_seg_len) * 300]

            wav_seg.append(y.to(self.accelerator.device))

        gt_mel_seg = torch.stack(gt_mel_seg).detach()

        wav_seg = torch.stack(wav_seg).float().detach().unsqueeze(1)
        w2v_seg = torch.stack(w2v_seg).float().detach()

        with torch.no_grad():
            real_norm = log_norm(gt_mel_seg.unsqueeze(1)).squeeze(1).detach()
            F0_real, _, _ = self.pitch_extractor(gt_mel_seg.unsqueeze(1))

        # normalize f0
        # Remove unvoiced frames (replace with -1)
        gt_glob_f0s = []
        f0_targets = []
        for bib in range(len(F0_real)):
            voiced_indices = F0_real[bib] > 5.0
            f0_voiced = F0_real[bib][voiced_indices]
            if len(f0_voiced) != 0:
                # Convert to log scale
                log_f0 = f0_voiced.log2()
                # Calculate mean and standard deviation
                mean_f0 = log_f0.mean()
                std_f0 = log_f0.std()
                # Normalize the F0 sequence
                normalized_f0 = (log_f0 - mean_f0) / std_f0
                # Create the normalized F0 sequence with unvoiced frames
                normalized_sequence = torch.zeros_like(F0_real[bib])
                normalized_sequence[voiced_indices] = normalized_f0
                normalized_sequence[~voiced_indices] = -10  # Assign -10 to unvoiced frames
                gt_glob_f0s.append(mean_f0)
            else:
                normalized_sequence = torch.zeros_like(F0_real[bib]) - 10.0
                gt_glob_f0s.append(torch.tensor(0.0).to(self.accelerator.device))

            # f0_targets.append(normalized_sequence[single_side_context // 200:-single_side_context // 200])
            f0_targets.append(normalized_sequence)
        f0_targets = torch.stack(f0_targets).to(self.accelerator.device)
        # fill nan with -10
        f0_targets[torch.isnan(f0_targets)] = -10.0
        # fill inf with -10
        f0_targets[torch.isinf(f0_targets)] = -10.0
        # if frame_rate not equal to 80, interpolate f0 from frame rate of 80 to target frame rate
        if self.cfg.preprocess_params.frame_rate != 80:
            f0_targets = F.interpolate(
                f0_targets.unsqueeze(1),
                mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
                mode="nearest",
            ).squeeze(1)
            w2v_seg = F.interpolate(
                w2v_seg,
                mel_seg_len // 80 * self.cfg.preprocess_params.frame_rate,
                mode="nearest",
            )

        wav_seg_input = wav_seg
        wav_seg_target = wav_seg

        z = self.model.encoder(wav_seg_input)
        z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
            z, wav_seg_input, n_c=2, full_waves=waves, wave_lens=wave_lengths
        )
        preds, rev_preds = self.model.fa_predictors(quantized, timbre)

        pred_wave = self.model.decoder(z)

        len_diff = wav_seg_target.size(-1) - pred_wave.size(-1)
        if len_diff > 0:
            wav_seg_target = wav_seg_target[..., len_diff // 2 : -len_diff // 2]

        # discriminator loss
        d_fake = self.model.discriminator(pred_wave.detach())
        d_real = self.model.discriminator(wav_seg_target)
        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += torch.mean(x_fake[-1] ** 2)
            loss_d += torch.mean((1 - x_real[-1]) ** 2)

        self.optimizer.zero_grad()
        self.accelerator.backward(loss_d)
        grad_norm_d = torch.nn.utils.clip_grad_norm_(
            self.model.discriminator.parameters(), 10.0
        )
        self.optimizer.step("discriminator")
        self.optimizer.scheduler(key="discriminator")

        # generator loss
        signal = AudioSignal(wav_seg_target, sample_rate=24000)
        recons = AudioSignal(pred_wave, sample_rate=24000)
        stft_loss = self.criterions["stft"](recons, signal)
        mel_loss = self.criterions["mel"](recons, signal)
        waveform_loss = self.criterions["l1"](recons, signal)

        d_fake = self.model.discriminator(pred_wave)
        d_real = self.model.discriminator(wav_seg_target)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += torch.mean((1 - x_fake[-1]) ** 2)

        loss_feature = 0

        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())

        pred_f0, pred_uv = preds["f0"], preds["uv"]
        rev_pred_f0, rev_pred_uv = rev_preds["rev_f0"], rev_preds["rev_uv"]

        common_min_size = min(pred_f0.size(-2), f0_targets.size(-1))
        f0_targets = f0_targets[..., :common_min_size]
        real_norm = real_norm[..., :common_min_size]

        f0_loss = F.smooth_l1_loss(
            f0_targets, pred_f0.squeeze(-1)[..., :common_min_size]
        )
        uv_loss = F.smooth_l1_loss(
            real_norm, pred_uv.squeeze(-1)[..., :common_min_size]
        )
        rev_f0_loss = (
            F.smooth_l1_loss(f0_targets, rev_pred_f0.squeeze(-1)[..., :common_min_size])
            if rev_pred_f0 is not None
            else torch.FloatTensor([0]).to(self.accelerator.device)
        )
        rev_uv_loss = (
            F.smooth_l1_loss(real_norm, rev_pred_uv.squeeze(-1)[..., :common_min_size])
            if rev_pred_uv is not None
            else torch.FloatTensor([0]).to(self.accelerator.device)
        )

        tot_f0_loss = f0_loss + rev_f0_loss
        tot_uv_loss = uv_loss + rev_uv_loss

        pred_content = preds["content"]
        rev_pred_content = rev_preds["rev_content"]

        target_content_latents = w2v_seg[..., :common_min_size]

        content_loss = self.criterions["content"](
            pred_content.transpose(1, 2)[..., :common_min_size],
            target_content_latents.long(),
        )
        rev_content_loss = (
            self.criterions["content"](
                rev_pred_content.transpose(1, 2)[..., :common_min_size],
                target_content_latents.long(),
            )
            if rev_pred_content is not None
            else torch.FloatTensor([0]).to(self.accelerator.device)
        )

        tot_content_loss = content_loss + rev_content_loss

        if self.speaker_model is not None:
            spk_logits = torch.cat(
                [
                    self.speaker_model.infer_segment(w16.cpu()[..., :wl])[1]
                    for w16, wl in zip(waves_16k, wave_lengths)
                ],
                dim=0,
            )
            spk_labels = spk_logits.argmax(dim=-1)
        else:
            spk_labels = torch.zeros([len(waves_16k)], dtype=torch.long).to(
                self.accelerator.device
            )

        spk_pred_logits = preds["timbre"]
        spk_loss = F.cross_entropy(spk_pred_logits, spk_labels)
        x_spk_pred_logits = rev_preds["x_timbre"]

        x_spk_loss = (
            F.cross_entropy(x_spk_pred_logits, spk_labels)
            if x_spk_pred_logits is not None
            else torch.FloatTensor([0]).to(self.accelerator.device)
        )

        tot_spk_loss = spk_loss + x_spk_loss

        loss_gen_all = (
            mel_loss * 15.0
            + loss_feature * 1.0
            + loss_g * 1.0
            + commitment_loss * 0.25
            + codebook_loss * 1.0
            + tot_f0_loss * 1.0
            + tot_uv_loss * 1.0
            + tot_content_loss * 5.0
            + tot_spk_loss * 5.0
        )

        self.optimizer.zero_grad()
        self.accelerator.backward(loss_gen_all)

        with torch.no_grad():
            total_loss = loss_gen_all.item()
            train_losses["stft"] = stft_loss.item()
            train_losses["mel"] = mel_loss.item()
            train_losses["l1"] = waveform_loss.item()
            train_losses["f0"] = f0_loss.item()
            train_losses["uv"] = uv_loss.item()
            train_losses["content"] = content_loss.item()
            train_losses["speaker"] = spk_loss.item()
            train_losses["rev_f0"] = rev_f0_loss.item()
            train_losses["rev_uv"] = rev_uv_loss.item()
            train_losses["rev_content"] = rev_content_loss.item()
            train_losses["rev_speaker"] = x_spk_loss.item()

            train_losses["feature"] = loss_feature.item()
            train_losses["generator"] = loss_g.item()

            train_losses["commitment"] = commitment_loss.item()
            train_losses["codebook"] = codebook_loss.item()

            # discriminators
            train_losses["discriminator"] = loss_d.item()

        return total_loss, train_losses

    def _inference(self, eval_wave):
        """Inference during training for test audios."""
        z = self.model.encoder(
            eval_wave[None, None, ...].to(self.accelerator.device).float()
        )
        z, quantized, commitment_loss, codebook_loss, timbre = self.model.quantizer(
            z, eval_wave[None, None, ...], n_c=self.cfg.model_params.n_c_codebooks
        )
        full_pred_wave = self.model.decoder(z)
        return full_pred_wave[0]

    def _load_model(self, checkpoint_path=None, resume_type="resume"):
        """Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.
        """
        if resume_type == "resume":
            if checkpoint_path is None:
                available_checkpoints = glob.glob(
                    os.path.join(self.checkpoint_dir, "FAcodec_epoch_*_step_*.pth")
                )
                # find the checkpoint that has the highest step number
                latest_checkpoint = max(
                    available_checkpoints,
                    key=lambda x: int(x.split("_")[-1].split(".")[0]),
                )
                earliest_checkpoint = min(
                    available_checkpoints,
                    key=lambda x: int(x.split("_")[-1].split(".")[0]),
                )
                # delete the earliest checkpoint
                if (
                    earliest_checkpoint != latest_checkpoint
                    and self.accelerator.is_main_process
                    and len(available_checkpoints) > 4
                ):
                    os.remove(earliest_checkpoint)
                    print(f"Removed {earliest_checkpoint}")
            else:
                latest_checkpoint = checkpoint_path

            self.model, self.optimizer, self.epoch, self.step = load_checkpoint(
                self.model,
                self.optimizer,
                latest_checkpoint,
                load_only_params=False,
                ignore_modules=[],
                is_distributed=self.accelerator.num_processes > 1,
            )
        else:
            raise ValueError("Invalid resume type")
        return checkpoint_path

    def _count_parameters(self):
        total_num = sum(
            sum(p.numel() for p in self.model[key].parameters()) for key in self.model
        )
        # trainable_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        return total_num
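For orientation (illustrative only, not from the commit): the adversarial terms in _train_step are the least-squares GAN objective applied to each sub-discriminator's final feature map (x[-1]), with the generator total a weighted sum dominated by the mel term (weight 15).

import torch

d_fake_last = torch.rand(4, 1, 100)  # stand-in for x_fake[-1]
d_real_last = torch.rand(4, 1, 100)  # stand-in for x_real[-1]
loss_d = torch.mean(d_fake_last**2) + torch.mean((1 - d_real_last) ** 2)  # push fake -> 0, real -> 1
loss_g = torch.mean((1 - d_fake_last) ** 2)                               # push fake -> 1
print(float(loss_d), float(loss_g))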
indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py  0 → 100644
indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7  0 → 100644

File added
indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py
0 → 100644
View file @
ab9c00af
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is borrowed from https://github.com/yl4579/PitchExtractor/blob/main/model.py
"""
Implementation of model from:
Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
Convolutional Recurrent Neural Networks" (2019)
Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
"""
import
torch
from
torch
import
nn
class
JDCNet
(
nn
.
Module
):
"""
Joint Detection and Classification Network model for singing voice melody.
"""
def
__init__
(
self
,
num_class
=
722
,
seq_len
=
31
,
leaky_relu_slope
=
0.01
):
super
().
__init__
()
self
.
num_class
=
num_class
# input = (b, 1, 31, 513), b = batch size
self
.
conv_block
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
1
,
out_channels
=
64
,
kernel_size
=
3
,
padding
=
1
,
bias
=
False
),
# out: (b, 64, 31, 513)
nn
.
BatchNorm2d
(
num_features
=
64
),
nn
.
LeakyReLU
(
leaky_relu_slope
,
inplace
=
True
),
nn
.
Conv2d
(
64
,
64
,
3
,
padding
=
1
,
bias
=
False
),
# (b, 64, 31, 513)
)
# res blocks
self
.
res_block1
=
ResBlock
(
in_channels
=
64
,
out_channels
=
128
)
# (b, 128, 31, 128)
self
.
res_block2
=
ResBlock
(
in_channels
=
128
,
out_channels
=
192
)
# (b, 192, 31, 32)
self
.
res_block3
=
ResBlock
(
in_channels
=
192
,
out_channels
=
256
)
# (b, 256, 31, 8)
# pool block
self
.
pool_block
=
nn
.
Sequential
(
nn
.
BatchNorm2d
(
num_features
=
256
),
nn
.
LeakyReLU
(
leaky_relu_slope
,
inplace
=
True
),
nn
.
MaxPool2d
(
kernel_size
=
(
1
,
4
)),
# (b, 256, 31, 2)
nn
.
Dropout
(
p
=
0.2
),
)
# maxpool layers (for auxiliary network inputs)
# in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
self
.
maxpool1
=
nn
.
MaxPool2d
(
kernel_size
=
(
1
,
40
))
# in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
self
.
maxpool2
=
nn
.
MaxPool2d
(
kernel_size
=
(
1
,
20
))
# in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
self
.
maxpool3
=
nn
.
MaxPool2d
(
kernel_size
=
(
1
,
10
))
# in = (b, 640, 31, 2), out = (b, 256, 31, 2)
self
.
detector_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
640
,
256
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
256
),
nn
.
LeakyReLU
(
leaky_relu_slope
,
inplace
=
True
),
nn
.
Dropout
(
p
=
0.2
),
)
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
self
.
bilstm_classifier
=
nn
.
LSTM
(
input_size
=
512
,
hidden_size
=
256
,
batch_first
=
True
,
bidirectional
=
True
)
# (b, 31, 512)
# input: (b, 31, 512) - resized from (b, 256, 31, 2)
self
.
bilstm_detector
=
nn
.
LSTM
(
input_size
=
512
,
hidden_size
=
256
,
batch_first
=
True
,
bidirectional
=
True
)
# (b, 31, 512)
# input: (b * 31, 512)
self
.
classifier
=
nn
.
Linear
(
in_features
=
512
,
out_features
=
self
.
num_class
)
# (b * 31, num_class)
# input: (b * 31, 512)
self
.
detector
=
nn
.
Linear
(
in_features
=
512
,
out_features
=
2
)
# (b * 31, 2) - binary classifier
# initialize weights
self
.
apply
(
self
.
init_weights
)
    def get_feature_GAN(self, x):
        seq_len = x.shape[-2]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return poolblock_out.transpose(-1, -2)

    def get_feature(self, x):
        seq_len = x.shape[-2]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return self.pool_block[2](poolblock_out)

    def forward(self, x):
        """
        Returns:
            classification_prediction, detection_prediction
            sizes: (b, 31, 722), (b, 31, 2)
        """
        ###############################
        # forward pass for classifier #
        ###############################
        seq_len = x.shape[-1]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)

        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)
        GAN_feature = poolblock_out.transpose(-1, -2)
        poolblock_out = self.pool_block[2](poolblock_out)

        # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
        classifier_out = (
            poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
        )
        classifier_out, _ = self.bilstm_classifier(
            classifier_out
        )  # ignore the hidden states

        classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
        classifier_out = self.classifier(classifier_out)
        classifier_out = classifier_out.view(
            (-1, seq_len, self.num_class)
        )  # (b, 31, num_class)

        # sizes: (b, 31, 722), (b, 31, 2)
        # classifier output consists of predicted pitch classes per frame
        # detector output consists of: (isvoice, notvoice) estimates per frame
        return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out

    @staticmethod
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
            for p in m.parameters():
                if p.data is None:
                    continue
                if len(p.shape) >= 2:
                    nn.init.orthogonal_(p.data)
                else:
                    nn.init.normal_(p.data)


class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
        super().__init__()
        self.downsample = in_channels != out_channels

        # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
        self.pre_conv = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
        )

        # conv layers
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        )

        # 1 x 1 convolution layer to match the feature dimensions
        self.conv1by1 = None
        if self.downsample:
            self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        x = self.pre_conv(x)
        if self.downsample:
            x = self.conv(x) + self.conv1by1(x)
        else:
            x = self.conv(x) + x
        return x
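

# A minimal shape-check sketch (the `_sketch_*` helper below is illustrative,
# not part of the upstream module): MaxPool2d(kernel_size=(1, 2)) in pre_conv
# halves only the last axis, and conv1by1 matches the channel count when
# downsampling, so (B, 64, 31, 100) -> (B, 128, 31, 50).
def _sketch_resblock_shapes():
    block = ResBlock(in_channels=64, out_channels=128)
    y = block(torch.randn(1, 64, 31, 100))
    assert y.shape == (1, 128, 31, 50)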
indextts/utils/maskgct/models/codec/facodec/modules/attentions.py
0 → 100644
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/attentions.py

import copy
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from . import commons


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)
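

# A minimal sketch (the `_sketch_*` helper is illustrative, not upstream API):
# this LayerNorm normalizes the channel axis of a channel-first (B, C, T)
# tensor by transposing around F.layer_norm, so at init (gamma=1, beta=0) each
# per-position channel vector comes out zero-mean.
def _sketch_channel_first_layer_norm():
    ln = LayerNorm(channels=8)
    x = torch.randn(2, 8, 50)
    y = ln(x)
    assert y.shape == x.shape
    assert torch.allclose(y.mean(dim=1), torch.zeros(2, 50), atol=1e-4)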
class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so that they add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
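

# A minimal shape sketch (the `_sketch_*` helper is illustrative, not upstream
# API): with window size w, relative logits live in [b, h, l, 2*l-1], and the
# pad/reshape/slice trick above converts them to absolute indexing [b, h, l, l]
# without any gather op; the inverse mapping restores the relative layout.
def _sketch_relative_position_shapes():
    mha = MultiHeadAttention(channels=8, out_channels=8, n_heads=2, window_size=4)
    rel = torch.randn(1, 2, 5, 9)  # l = 5 -> 2*l - 1 = 9
    absolute = mha._relative_position_to_absolute_position(rel)
    assert absolute.shape == (1, 2, 5, 5)
    back = mha._absolute_position_to_relative_position(absolute)
    assert back.shape == (1, 2, 5, 9)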
class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
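

# A minimal sketch (the `_sketch_*` helper is illustrative, not upstream API):
# with causal=True the FFN left-pads (kernel_size - 1) frames before each conv,
# so the output length matches the input length and no future frame leaks in.
def _sketch_ffn_causal_padding():
    ffn = FFN(in_channels=8, out_channels=8, filter_channels=16, kernel_size=3,
              causal=True)
    x = torch.randn(1, 8, 10)
    x_mask = torch.ones(1, 1, 10)
    assert ffn(x, x_mask).shape == (1, 8, 10)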
indextts/utils/maskgct/models/codec/facodec/modules/commons.py
0 → 100644
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import os.path

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape
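

# A minimal sketch (the `_sketch_*` helper is illustrative, not upstream API):
# pad specs in this codebase are written per dimension, first dimension first,
# while F.pad wants a flat list starting from the *last* dimension, so the
# helper reverses and then flattens.
def _sketch_convert_pad_shape():
    assert convert_pad_shape([[0, 0], [0, 0], [1, 0]]) == [1, 0, 0, 0, 0, 0]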
def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def slice_segments_audio(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
        dtype=torch.long
    )
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
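

# A minimal usage sketch (the `_sketch_*` helper is illustrative, not upstream
# API): one random window is drawn per batch element, and the returned
# `ids_str` can be reused to slice aligned targets (e.g. the matching audio via
# slice_segments_audio).
def _sketch_rand_slice_segments():
    x = torch.randn(2, 192, 100)
    lengths = torch.tensor([100, 60])
    seg, ids_str = rand_slice_segments(x, lengths, segment_size=32)
    assert seg.shape == (2, 192, 32)
    assert bool(((ids_str >= 0) & (ids_str <= lengths - 32)).all())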
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask
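

# A minimal sketch (the `_sketch_*` helper is illustrative, not upstream API):
# subsequent_mask builds the lower-triangular (1, 1, L, L) mask the attention
# Decoder consumes, so position t may only attend to positions <= t.
def _sketch_subsequent_mask():
    m = subsequent_mask(3)
    expected = torch.tensor([[[[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]]])
    assert torch.equal(m, expected)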
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
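

# A minimal worked sketch (the `_sketch_*` helper is illustrative, not upstream
# API): durations [2, 1, 3] expand into a monotonic 0/1 alignment of shape
# (b, 1, t_y, t_x); every output frame row selects exactly one input token.
def _sketch_generate_path():
    duration = torch.tensor([[[2.0, 1.0, 3.0]]])
    mask = torch.ones(1, 1, 6, 3)
    path = generate_path(duration, mask)
    assert path.shape == (1, 1, 6, 3)
    assert torch.equal(path[0, 0].sum(dim=1), torch.ones(6))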
def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)
    """
    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
    return x


from huggingface_hub import hf_hub_download


def load_F0_models(path):
    # load F0 model
    from .JDC.model import JDCNet

    F0_model = JDCNet(num_class=1, seq_len=192)
    if not os.path.exists(path):
        path = hf_hub_download(repo_id="Plachta/JDCnet", filename="bst.t7")
    params = torch.load(path, map_location="cpu")["net"]
    F0_model.load_state_dict(params)
    _ = F0_model.train()

    return F0_model
# Generators
from modules.dac.model.dac import Encoder, Decoder
from .quantize import FAquantizer, FApredictors

# Discriminators
from modules.dac.model.discriminator import Discriminator


def build_model(args):
    encoder = Encoder(
        d_model=args.DAC.encoder_dim,
        strides=args.DAC.encoder_rates,
        d_latent=1024,
        causal=args.causal,
        lstm=args.lstm,
    )

    quantizer = FAquantizer(
        in_dim=1024,
        n_p_codebooks=1,
        n_c_codebooks=args.n_c_codebooks,
        n_t_codebooks=2,
        n_r_codebooks=3,
        codebook_size=1024,
        codebook_dim=8,
        quantizer_dropout=0.5,
        causal=args.causal,
        separate_prosody_encoder=args.separate_prosody_encoder,
        timbre_norm=args.timbre_norm,
    )

    fa_predictors = FApredictors(
        in_dim=1024,
        use_gr_content_f0=args.use_gr_content_f0,
        use_gr_prosody_phone=args.use_gr_prosody_phone,
        use_gr_residual_f0=True,
        use_gr_residual_phone=True,
        use_gr_timbre_content=True,
        use_gr_timbre_prosody=args.use_gr_timbre_prosody,
        use_gr_x_timbre=True,
        norm_f0=args.norm_f0,
        timbre_norm=args.timbre_norm,
        use_gr_content_global_f0=args.use_gr_content_global_f0,
    )

    decoder = Decoder(
        input_channel=1024,
        channels=args.DAC.decoder_dim,
        rates=args.DAC.decoder_rates,
        causal=args.causal,
        lstm=args.lstm,
    )

    discriminator = Discriminator(
        rates=[],
        periods=[2, 3, 5, 7, 11],
        fft_sizes=[2048, 1024, 512],
        sample_rate=args.DAC.sr,
        bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
    )

    nets = Munch(
        encoder=encoder,
        quantizer=quantizer,
        decoder=decoder,
        discriminator=discriminator,
        fa_predictors=fa_predictors,
    )

    return nets
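

# A hypothetical config sketch (the `_sketch_*` helper and all field *values*
# are illustrative, not the shipped configuration): build_model expects an
# attribute-style config carrying exactly the fields it reads above.
def _sketch_build_model_config():
    cfg = recursive_munch(
        {
            "DAC": {
                "encoder_dim": 64,
                "encoder_rates": [2, 4, 5, 5],
                "decoder_dim": 1536,
                "decoder_rates": [5, 5, 4, 2],
                "sr": 24000,
            },
            "causal": True,
            "lstm": 2,
            "n_c_codebooks": 2,
            "separate_prosody_encoder": True,
            "timbre_norm": True,
            "use_gr_content_f0": False,
            "use_gr_prosody_phone": False,
            "use_gr_timbre_prosody": False,
            "norm_f0": True,
            "use_gr_content_global_f0": True,
        }
    )
    return build_model(cfg)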
def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
):
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    for key in model:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            print("%s loaded" % key)
            model[key].load_state_dict(params[key], strict=True)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])
    else:
        epoch = state["epoch"] + 1
        iters = state["iters"]

    return model, optimizer, epoch, iters


def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d
indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py
0 → 100644
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.autograd import Function
import torch
from torch import nn


class GradientReversal(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x, alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        _, alpha = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = -alpha * grad_output
        return grad_input, None


revgrad = GradientReversal.apply


class GradientReversal(nn.Module):
    def __init__(self, alpha):
        super().__init__()
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self.alpha)
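

# A minimal usage sketch (the `_sketch_*` helper is illustrative, not upstream
# API): the module is the identity in the forward pass but flips the gradient
# sign (scaled by alpha) in backward, which is what lets the adversarial
# predictors push information *out* of a latent.
def _sketch_gradient_reversal():
    grl = GradientReversal(alpha=1.0)
    x = torch.ones(3, requires_grad=True)
    grl(x).sum().backward()
    assert torch.equal(x.grad, -torch.ones(3))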
indextts/utils/maskgct/models/codec/facodec/modules/layers.py
0 → 100644
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
from torch import nn
from typing import Optional, Any
from torch import Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F

import random

random.seed(0)


def _get_activation_fn(activ):
    if activ == "relu":
        return nn.ReLU()
    elif activ == "lrelu":
        return nn.LeakyReLU(0.2)
    elif activ == "swish":
        return lambda x: x * torch.sigmoid(x)
    else:
        raise RuntimeError(
            "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
        )


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
        )

    def forward(self, x):
        return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
        )

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal
class CausualConv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=1,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2) * 2
            self.padding = padding  # store the trim length used in forward()
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            bias=bias,
        )

        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
        )

    def forward(self, x):
        x = self.conv(x)
        x = x[:, :, : -self.padding]
        return x
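

# A minimal sketch (the `_sketch_*` helper is illustrative, not upstream API):
# the layer pads 2 * padding frames on both sides and then trims the trailing
# frames, so the convolution is length-preserving and causal.
def _sketch_causual_conv():
    conv = CausualConv(4, 4, kernel_size=3, padding=1)
    x = torch.randn(1, 4, 10)
    assert conv(x).shape == (1, 4, 10)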
class CausualBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
        layers = [
            CausualConv(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)


class ConvBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
        layers = [
            ConvNorm(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)
class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2,
            attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=padding,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention
class Attention(nn.Module):
    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights + processed_memory)
        )

        energies = energies.squeeze(-1)
        return energies

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
    ):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class ForwardAttentionV2(nn.Module):
    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super(ForwardAttentionV2, self).__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights + processed_memory)
        )

        energies = energies.squeeze(-1)
        return energies

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
        log_alpha,
    ):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )

        # log_energy =

        if mask is not None:
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        # attention_weights = F.softmax(alignment, dim=1)

        # content_score = log_energy.unsqueeze(1)  # [B, MAX_TIME] -> [B, 1, MAX_TIME]
        # log_alpha = log_alpha.unsqueeze(2)  # [B, MAX_TIME] -> [B, MAX_TIME, 1]

        # log_total_score = log_alpha + content_score

        # previous_attention_weights = attention_weights_cat[:, 0, :]

        log_alpha_shift_padded = []
        max_time = log_energy.size(1)
        for sft in range(2):
            shifted = log_alpha[:, : max_time - sft]
            shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))

        biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)

        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)

        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights, log_alpha_new
class PhaseShuffle2d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle2d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x.size = (B, C, M, L)
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :, :move]
            right = x[:, :, :, move:]
            shuffled = torch.cat([right, left], dim=3)
        return shuffled


class PhaseShuffle1d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle1d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x.size = (B, C, M, L)
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :move]
            right = x[:, :, move:]
            shuffled = torch.cat([right, left], dim=2)

        return shuffled
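

# A minimal sketch (the `_sketch_*` helper is illustrative, not upstream API):
# phase shuffle is a circular shift along time by a random offset in [-n, n];
# with an explicit `move` it reduces to a plain roll.
def _sketch_phase_shuffle():
    ps = PhaseShuffle1d(n=2)
    x = torch.arange(8.0).view(1, 1, 8)
    y = ps(x, move=2)
    assert torch.equal(y, torch.roll(x, shifts=-2, dims=2))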
class MFCC(nn.Module):
    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer("dct_mat", dct_mat)

    def forward(self, mel_specgram):
        if len(mel_specgram.shape) == 2:
            mel_specgram = mel_specgram.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False
        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        if unsqueezed:
            mfcc = mfcc.squeeze(0)
        return mfcc
indextts/utils/maskgct/models/codec/facodec/modules/quantize.py
0 → 100644
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from modules.dac.nn.quantize import ResidualVectorQuantize
from torch import nn
from .wavenet import WN
from .style_encoder import StyleEncoder
from .gradient_reversal import GradientReversal
import torch
import torchaudio
import torchaudio.functional as audio_F
import numpy as np
from ..alias_free_torch import *
from torch.nn.utils import weight_norm
from torch import nn, sin, pow
from einops.layers.torch import Rearrange
from modules.dac.model.encodec import SConv1d


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
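

# A minimal sketch (the `_sketch_*` helper is illustrative, not upstream API):
# with alpha_logscale=True both parameters start at zero, so exp(0) = 1 and the
# activation reduces to x + sin^2(x), up to the 1e-9 division guard.
def _sketch_snake_beta():
    act = SnakeBeta(4, alpha_logscale=True)
    x = torch.randn(2, 4, 16)
    y = act(x)
    assert torch.allclose(y, x + torch.sin(x) ** 2, atol=1e-6)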
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        return x + self.block(x)


class CNNLSTM(nn.Module):
    def __init__(self, indim, outdim, head, global_pred=False):
        super().__init__()
        self.global_pred = global_pred
        self.model = nn.Sequential(
            ResidualUnit(indim, dilation=1),
            ResidualUnit(indim, dilation=2),
            ResidualUnit(indim, dilation=3),
            Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
            Rearrange("b c t -> b t c"),
        )
        self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])

    def forward(self, x):
        # x: [B, C, T]
        x = self.model(x)
        if self.global_pred:
            x = torch.mean(x, dim=1, keepdim=False)
        outs = [head(x) for head in self.heads]
        return outs
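

# A minimal shape sketch (the `_sketch_*` helper is illustrative, not upstream
# API, and assumes the repo's alias_free_torch Activation1d import resolves):
# despite its name the module is residual convs + Snake + linear heads; each
# head maps (B, C, T) -> (B, T, outdim), or (B, outdim) when global_pred pools
# over time.
def _sketch_cnnlstm_heads():
    m = CNNLSTM(indim=8, outdim=2, head=2)
    outs = m(torch.randn(1, 8, 16))
    assert len(outs) == 2 and outs[0].shape == (1, 16, 2)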
def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


class MFCC(nn.Module):
    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer("dct_mat", dct_mat)

    def forward(self, mel_specgram):
        if len(mel_specgram.shape) == 2:
            mel_specgram = mel_specgram.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False
        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        if unsqueezed:
            mfcc = mfcc.squeeze(0)
        return mfcc
class FAquantizer(nn.Module):
    def __init__(
        self,
        in_dim=1024,
        n_p_codebooks=1,
        n_c_codebooks=2,
        n_t_codebooks=2,
        n_r_codebooks=3,
        codebook_size=1024,
        codebook_dim=8,
        quantizer_dropout=0.5,
        causal=False,
        separate_prosody_encoder=False,
        timbre_norm=False,
    ):
        super(FAquantizer, self).__init__()
        conv1d_type = SConv1d  # if causal else nn.Conv1d
        self.prosody_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_p_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.content_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_c_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        if not timbre_norm:
            self.timbre_quantizer = ResidualVectorQuantize(
                input_dim=in_dim,
                n_codebooks=n_t_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            self.timbre_encoder = StyleEncoder(
                in_dim=80, hidden_dim=512, out_dim=in_dim
            )
            self.timbre_linear = nn.Linear(1024, 1024 * 2)
            self.timbre_linear.bias.data[:1024] = 1
            self.timbre_linear.bias.data[1024:] = 0
            self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)

        self.residual_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_r_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        if separate_prosody_encoder:
            self.melspec_linear = conv1d_type(
                in_channels=20, out_channels=256, kernel_size=1, causal=causal
            )
            self.melspec_encoder = WN(
                hidden_channels=256,
                kernel_size=5,
                dilation_rate=1,
                n_layers=8,
                gin_channels=0,
                p_dropout=0.2,
                causal=causal,
            )
            self.melspec_linear2 = conv1d_type(
                in_channels=256, out_channels=1024, kernel_size=1, causal=causal
            )
        else:
            pass
        self.separate_prosody_encoder = separate_prosody_encoder

        self.prob_random_mask_residual = 0.75

        SPECT_PARAMS = {
            "n_fft": 2048,
            "win_length": 1200,
            "hop_length": 300,
        }
        MEL_PARAMS = {
            "n_mels": 80,
        }

        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
        )
        self.mel_mean, self.mel_std = -4, 4
        self.frame_rate = 24000 / 300
        self.hop_length = 300

        self.is_timbre_norm = timbre_norm
        if timbre_norm:
            self.forward = self.forward_v2

    def preprocess(self, wave_tensor, n_bins=20):
        mel_tensor = self.to_mel(wave_tensor.squeeze(1))
        mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
        return mel_tensor[:, :n_bins, : int(wave_tensor.size(-1) / self.hop_length)]

    @torch.no_grad()
    def decode(self, codes):
        code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)

        z_c = self.content_quantizer.from_codes(code_c)[0]
        z_p = self.prosody_quantizer.from_codes(code_p)[0]
        z_t = self.timbre_quantizer.from_codes(code_t)[0]

        z = z_c + z_p + z_t

        return z, [z_c, z_p, z_t]

    @torch.no_grad()
    def encode(self, x, wave_segments, n_c=1):
        outs = 0
        if self.separate_prosody_encoder:
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, T, 20)
            f0_input = self.melspec_linear(f0_input)
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        timbre_residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_t,
            codes_t,
            latents_t,
            commitment_loss_t,
            codebook_loss_t,
        ) = self.timbre_quantizer(timbre_residual_feature, 2)
        outs += z_t  # we should not detach timbre

        residual_feature = timbre_residual_feature - z_t

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]

    def forward(
        self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2
    ):
        # timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
        # timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
        outs = 0
        if self.separate_prosody_encoder:
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, T, 20)
            f0_input = self.melspec_linear(f0_input)
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        timbre_residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_t,
            codes_t,
            latents_t,
            commitment_loss_t,
            codebook_loss_t,
        ) = self.timbre_quantizer(timbre_residual_feature, n_t)
        outs += z_t  # we should not detach timbre

        residual_feature = timbre_residual_feature - z_t

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        bsz = z_r.shape[0]
        res_mask = np.random.choice(
            [0, 1],
            size=bsz,
            p=[
                self.prob_random_mask_residual,
                1 - self.prob_random_mask_residual,
            ],
        )
        res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)  # (B, 1, 1)
        res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)
        noise_must_on = noise_added_flags * recon_noisy_flags
        noise_must_off = noise_added_flags * (~recon_noisy_flags)
        res_mask[noise_must_on] = 1
        res_mask[noise_must_off] = 0

        outs += z_r * res_mask

        quantized = [z_p, z_c, z_t, z_r]
        commitment_losses = (
            commitment_loss_p
            + commitment_loss_c
            + commitment_loss_t
            + commitment_loss_r
        )
        codebook_losses = (
            codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
        )

        return outs, quantized, commitment_losses, codebook_losses

    def forward_v2(
        self,
        x,
        wave_segments,
        n_c=1,
        n_t=2,
        full_waves=None,
        wave_lens=None,
        return_codes=False,
    ):
        # timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
        if full_waves is None:
            mel = self.preprocess(wave_segments, n_bins=80)
            timbre = self.timbre_encoder(
                mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)
            )
        else:
            mel = self.preprocess(full_waves, n_bins=80)
            timbre = self.timbre_encoder(
                mel,
                sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1),
            )
        outs = 0
        if self.separate_prosody_encoder:
            prosody_feature = self.preprocess(wave_segments)

            f0_input = prosody_feature  # (B, T, 20)
            f0_input = self.melspec_linear(f0_input)
            f0_input = self.melspec_encoder(
                f0_input,
                torch.ones(f0_input.shape[0], 1, f0_input.shape[2])
                .to(f0_input.device)
                .bool(),
            )
            f0_input = self.melspec_linear2(f0_input)

            common_min_size = min(f0_input.size(2), x.size(2))
            f0_input = f0_input[:, :, :common_min_size]

            x = x[:, :, :common_min_size]

            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(f0_input, 1)
            outs += z_p.detach()
        else:
            (
                z_p,
                codes_p,
                latents_p,
                commitment_loss_p,
                codebook_loss_p,
            ) = self.prosody_quantizer(x, 1)
            outs += z_p.detach()

        (
            z_c,
            codes_c,
            latents_c,
            commitment_loss_c,
            codebook_loss_c,
        ) = self.content_quantizer(x, n_c)
        outs += z_c.detach()

        residual_feature = x - z_p.detach() - z_c.detach()

        (
            z_r,
            codes_r,
            latents_r,
            commitment_loss_r,
            codebook_loss_r,
        ) = self.residual_quantizer(residual_feature, 3)

        bsz = z_r.shape[0]
        res_mask = np.random.choice(
            [0, 1],
            size=bsz,
            p=[
                self.prob_random_mask_residual,
                1 - self.prob_random_mask_residual,
            ],
        )
        res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)  # (B, 1, 1)
        res_mask = res_mask.to(device=z_r.device, dtype=z_r.dtype)

        if not self.training:
            res_mask = torch.ones_like(res_mask)

        outs += z_r * res_mask

        quantized = [z_p, z_c, z_r]
        codes = [codes_p, codes_c, codes_r]
        commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
        codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r

        style = self.timbre_linear(timbre).unsqueeze(2)  # (B, 2d, 1)
        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
        outs = outs.transpose(1, 2)
        outs = self.timbre_norm(outs)
        outs = outs.transpose(1, 2)
        outs = outs * gamma + beta

        if return_codes:
            return outs, quantized, commitment_losses, codebook_losses, timbre, codes
        else:
            return outs, quantized, commitment_losses, codebook_losses, timbre

    def voice_conversion(self, z, ref_wave):
        ref_mel = self.preprocess(ref_wave, n_bins=80)
        ref_timbre = self.timbre_encoder(
            ref_mel,
            sequence_mask(
                torch.LongTensor([ref_wave.size(-1)]).to(z.device) // self.hop_length,
                ref_mel.size(-1),
            ).unsqueeze(1),
        )
        style = self.timbre_linear(ref_timbre).unsqueeze(2)  # (B, 2d, 1)
        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
        outs = z.transpose(1, 2)
        outs = self.timbre_norm(outs)
        outs = outs.transpose(1, 2)
        outs = outs * gamma + beta

        return outs
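

# A minimal sketch of the code layout decode() expects (the `_sketch_*` helper
# is illustrative, not upstream API, and assumes a quantizer built with
# timbre_norm=False): dim 1 stacks 1 content + 1 prosody + 2 timbre codebook
# indices, i.e. codes.split([1, 1, 2], dim=1).
def _sketch_fa_code_layout():
    codes = torch.randint(0, 1024, (1, 4, 50))
    code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)
    assert code_c.shape == (1, 1, 50) and code_t.shape == (1, 2, 50)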
class FApredictors(nn.Module):
    def __init__(
        self,
        in_dim=1024,
        use_gr_content_f0=False,
        use_gr_prosody_phone=False,
        use_gr_residual_f0=False,
        use_gr_residual_phone=False,
        use_gr_timbre_content=True,
        use_gr_timbre_prosody=True,
        use_gr_x_timbre=False,
        norm_f0=True,
        timbre_norm=False,
        use_gr_content_global_f0=False,
    ):
        super(FApredictors, self).__init__()
        self.f0_predictor = CNNLSTM(in_dim, 1, 2)
        self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
        if timbre_norm:
            self.timbre_predictor = nn.Linear(in_dim, 20000)
        else:
            self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)

        self.use_gr_content_f0 = use_gr_content_f0
        self.use_gr_prosody_phone = use_gr_prosody_phone
        self.use_gr_residual_f0 = use_gr_residual_f0
        self.use_gr_residual_phone = use_gr_residual_phone
        self.use_gr_timbre_content = use_gr_timbre_content
        self.use_gr_timbre_prosody = use_gr_timbre_prosody
        self.use_gr_x_timbre = use_gr_x_timbre

        self.rev_f0_predictor = nn.Sequential(
            GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
        )
        self.rev_content_predictor = nn.Sequential(
            GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
        )
        self.rev_timbre_predictor = nn.Sequential(
            GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
        )

        self.norm_f0 = norm_f0
        self.timbre_norm = timbre_norm
        if timbre_norm:
            self.forward = self.forward_v2
            self.global_f0_predictor = nn.Linear(in_dim, 1)

        self.use_gr_content_global_f0 = use_gr_content_global_f0
        if use_gr_content_global_f0:
            self.rev_global_f0_predictor = nn.Sequential(
                GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
            )

    def forward(self, quantized):
        prosody_latent = quantized[0]
        content_latent = quantized[1]
        timbre_latent = quantized[2]
        residual_latent = quantized[3]
        content_pred = self.phone_predictor(content_latent)[0]

        if self.norm_f0:
            spk_pred = self.timbre_predictor(timbre_latent)[0]
            f0_pred, uv_pred = self.f0_predictor(prosody_latent)
        else:
            spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
            f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)

        prosody_rev_latent = torch.zeros_like(quantized[0])
        if self.use_gr_content_f0:
            prosody_rev_latent += quantized[1]
        if self.use_gr_timbre_prosody:
            prosody_rev_latent += quantized[2]
        if self.use_gr_residual_f0:
            prosody_rev_latent += quantized[3]
        rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)

        content_rev_latent = torch.zeros_like(quantized[1])
        if self.use_gr_prosody_phone:
            content_rev_latent += quantized[0]
        if self.use_gr_timbre_content:
            content_rev_latent += quantized[2]
        if self.use_gr_residual_phone:
            content_rev_latent += quantized[3]
        rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]

        if self.norm_f0:
            timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
        else:
            timbre_rev_latent = quantized[1] + quantized[3]
        if self.use_gr_x_timbre:
            x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
        else:
            x_spk_pred = None

        preds = {
            "f0": f0_pred,
            "uv": uv_pred,
            "content": content_pred,
            "timbre": spk_pred,
        }

        rev_preds = {
            "rev_f0": rev_f0_pred,
            "rev_uv": rev_uv_pred,
            "rev_content": rev_content_pred,
            "x_timbre": x_spk_pred,
        }
        return preds, rev_preds

    def forward_v2(self, quantized, timbre):
        prosody_latent = quantized[0]
        content_latent = quantized[1]
        residual_latent = quantized[2]
        content_pred = self.phone_predictor(content_latent)[0]

        spk_pred = self.timbre_predictor(timbre)

        f0_pred, uv_pred = self.f0_predictor(prosody_latent)

        prosody_rev_latent = torch.zeros_like(prosody_latent)
        if self.use_gr_content_f0:
            prosody_rev_latent += content_latent
        if self.use_gr_residual_f0:
            prosody_rev_latent += residual_latent
        rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)

        content_rev_latent = torch.zeros_like(content_latent)
        if self.use_gr_prosody_phone:
            content_rev_latent += prosody_latent
        if self.use_gr_residual_phone:
            content_rev_latent += residual_latent
        rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]

        timbre_rev_latent = prosody_latent + content_latent + residual_latent
        if self.use_gr_x_timbre:
            x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
        else:
            x_spk_pred = None

        preds = {
            "f0": f0_pred,
            "uv": uv_pred,
            "content": content_pred,
            "timbre": spk_pred,
        }

        rev_preds = {
            "rev_f0": rev_f0_pred,
            "rev_uv": rev_uv_pred,
            "rev_content": rev_content_pred,
            "x_timbre": x_spk_pred,
        }
        return preds, rev_preds
indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py
0 → 100644
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/styleencoder.py

from . import attentions
from torch import nn
import torch
from torch.nn import functional as F


class Mish(nn.Module):
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class Conv1dGLU(nn.Module):
    """
    Conv1d + GLU(Gated Linear Unit) with residual connection.
    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        self.conv1 = nn.Conv1d(
            in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x
class
StyleEncoder
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
in_dim
=
513
,
hidden_dim
=
128
,
out_dim
=
256
):
super
().
__init__
()
self
.
in_dim
=
in_dim
# Linear 513 wav2vec 2.0 1024
self
.
hidden_dim
=
hidden_dim
self
.
out_dim
=
out_dim
self
.
kernel_size
=
5
self
.
n_head
=
2
self
.
dropout
=
0.1
self
.
spectral
=
nn
.
Sequential
(
nn
.
Conv1d
(
self
.
in_dim
,
self
.
hidden_dim
,
1
),
Mish
(),
nn
.
Dropout
(
self
.
dropout
),
nn
.
Conv1d
(
self
.
hidden_dim
,
self
.
hidden_dim
,
1
),
Mish
(),
nn
.
Dropout
(
self
.
dropout
),
)
self
.
temporal
=
nn
.
Sequential
(
Conv1dGLU
(
self
.
hidden_dim
,
self
.
hidden_dim
,
self
.
kernel_size
,
self
.
dropout
),
Conv1dGLU
(
self
.
hidden_dim
,
self
.
hidden_dim
,
self
.
kernel_size
,
self
.
dropout
),
)
self
.
slf_attn
=
attentions
.
MultiHeadAttention
(
self
.
hidden_dim
,
self
.
hidden_dim
,
self
.
n_head
,
p_dropout
=
self
.
dropout
,
proximal_bias
=
False
,
proximal_init
=
True
,
)
self
.
atten_drop
=
nn
.
Dropout
(
self
.
dropout
)
self
.
fc
=
nn
.
Conv1d
(
self
.
hidden_dim
,
self
.
out_dim
,
1
)
def
forward
(
self
,
x
,
mask
=
None
):
# spectral
x
=
self
.
spectral
(
x
)
*
mask
# temporal
x
=
self
.
temporal
(
x
)
*
mask
# self-attention
attn_mask
=
mask
.
unsqueeze
(
2
)
*
mask
.
unsqueeze
(
-
1
)
y
=
self
.
slf_attn
(
x
,
x
,
attn_mask
=
attn_mask
)
x
=
x
+
self
.
atten_drop
(
y
)
# fc
x
=
self
.
fc
(
x
)
# temoral average pooling
w
=
self
.
temporal_avg_pool
(
x
,
mask
=
mask
)
return
w
def
temporal_avg_pool
(
self
,
x
,
mask
=
None
):
if
mask
is
None
:
out
=
torch
.
mean
(
x
,
dim
=
2
)
else
:
len_
=
mask
.
sum
(
dim
=
2
)
x
=
x
.
sum
(
dim
=
2
)
out
=
torch
.
div
(
x
,
len_
)
return
out
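Assuming the sibling attentions module is importable (StyleEncoder uses a relative import, so this has to run inside the package), a small shape-check sketch; the input sizes here are hypothetical:

import torch
# e.g. from indextts.utils.maskgct.models.codec.facodec.modules.style_encoder import StyleEncoder

x = torch.randn(2, 513, 100)   # hypothetical [batch, in_dim, frames] linear spectrogram
mask = torch.ones(2, 1, 100)   # 1 = valid frame, 0 = padding

enc = StyleEncoder(in_dim=513, hidden_dim=128, out_dim=256)
style = enc(x, mask)           # masked temporal average pooling over frames
print(style.shape)             # torch.Size([2, 256]): one style vector per utterance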
indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py
0 → 100644
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This code is modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/ttv_v1/modules.py
import math

import torch
from torch import nn
from torch.nn import functional as F

from modules.dac.model.encodec import SConv1d

from . import commons

LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask


class WN(torch.nn.Module):
    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
        causal=False,
    ):
        super(WN, self).__init__()
        conv1d_type = SConv1d
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            self.cond_layer = conv1d_type(
                gin_channels, 2 * hidden_channels * n_layers, 1, norm="weight_norm"
            )

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = conv1d_type(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
                norm="weight_norm",
                causal=causal,
            )
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = conv1d_type(
                hidden_channels, res_skip_channels, 1, norm="weight_norm", causal=causal
            )
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)