Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
InspireMusic_pytorch
Commits
0112b0f0
Commit
0112b0f0
authored
Feb 14, 2025
by
chenzk
Browse files
v1.0
parents
Pipeline
#2394
canceled with stages
Changes
474
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1544 additions
and
0 deletions
+1544
-0
third_party/Matcha-TTS/matcha/onnx/infer.py
third_party/Matcha-TTS/matcha/onnx/infer.py
+168
-0
third_party/Matcha-TTS/matcha/text/__init__.py
third_party/Matcha-TTS/matcha/text/__init__.py
+57
-0
third_party/Matcha-TTS/matcha/text/cleaners.py
third_party/Matcha-TTS/matcha/text/cleaners.py
+144
-0
third_party/Matcha-TTS/matcha/text/numbers.py
third_party/Matcha-TTS/matcha/text/numbers.py
+71
-0
third_party/Matcha-TTS/matcha/text/symbols.py
third_party/Matcha-TTS/matcha/text/symbols.py
+17
-0
third_party/Matcha-TTS/matcha/train.py
third_party/Matcha-TTS/matcha/train.py
+122
-0
third_party/Matcha-TTS/matcha/utils/__init__.py
third_party/Matcha-TTS/matcha/utils/__init__.py
+5
-0
third_party/Matcha-TTS/matcha/utils/audio.py
third_party/Matcha-TTS/matcha/utils/audio.py
+82
-0
third_party/Matcha-TTS/matcha/utils/data/__init__.py
third_party/Matcha-TTS/matcha/utils/data/__init__.py
+0
-0
third_party/Matcha-TTS/matcha/utils/data/hificaptain.py
third_party/Matcha-TTS/matcha/utils/data/hificaptain.py
+148
-0
third_party/Matcha-TTS/matcha/utils/data/ljspeech.py
third_party/Matcha-TTS/matcha/utils/data/ljspeech.py
+97
-0
third_party/Matcha-TTS/matcha/utils/data/utils.py
third_party/Matcha-TTS/matcha/utils/data/utils.py
+53
-0
third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py
...party/Matcha-TTS/matcha/utils/generate_data_statistics.py
+110
-0
third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py
...tcha-TTS/matcha/utils/get_durations_from_trained_model.py
+195
-0
third_party/Matcha-TTS/matcha/utils/instantiators.py
third_party/Matcha-TTS/matcha/utils/instantiators.py
+56
-0
third_party/Matcha-TTS/matcha/utils/logging_utils.py
third_party/Matcha-TTS/matcha/utils/logging_utils.py
+53
-0
third_party/Matcha-TTS/matcha/utils/model.py
third_party/Matcha-TTS/matcha/utils/model.py
+90
-0
third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py
...party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py
+22
-0
third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx
third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx
+47
-0
third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py
third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py
+7
-0
No files found.
third_party/Matcha-TTS/matcha/onnx/infer.py
0 → 100644
View file @
0112b0f0
import
argparse
import
os
import
warnings
from
pathlib
import
Path
from
time
import
perf_counter
import
numpy
as
np
import
onnxruntime
as
ort
import
soundfile
as
sf
import
torch
from
matcha.cli
import
plot_spectrogram_to_numpy
,
process_text
def validate_args(args):
    """Sanity-check parsed CLI arguments for ONNX inference.

    :param args: argparse.Namespace with at least ``text``, ``file``,
        ``temperature`` and ``speaking_rate`` attributes.
    :return: the same namespace, unchanged.
    :raises AssertionError: when no input text source is given, the
        temperature is negative, or the speaking rate is not positive.
    """
    assert (
        args.text or args.file
    ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
    assert args.temperature >= 0, "Sampling temperature cannot be negative"
    # Bug fix: the original checked `>= 0`, contradicting the message below —
    # a speaking rate of 0 must be rejected (it would zero out durations).
    assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
    return args
def write_wavs(model, inputs, output_dir, external_vocoder=None):
    """Run ONNX inference and write one 22.05 kHz PCM_24 wav per batch item.

    When *external_vocoder* is None the model graph emits waveforms directly;
    otherwise the model emits mels which are fed through *external_vocoder*.
    Timing and real-time-factor figures are printed to stdout.
    """
    if external_vocoder is None:
        print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
        start = perf_counter()
        wavs, wav_lengths = model.run(None, inputs)
        infer_secs = perf_counter() - start
        mel_infer_secs = vocoder_infer_secs = None
    else:
        print("[🍵] Generating mel using Matcha")
        mel_start = perf_counter()
        mels, mel_lengths = model.run(None, inputs)
        mel_infer_secs = perf_counter() - mel_start

        print("Generating waveform from mel using external vocoder")
        vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
        vocoder_start = perf_counter()
        wavs = external_vocoder.run(None, vocoder_inputs)[0]
        vocoder_infer_secs = perf_counter() - vocoder_start

        wavs = wavs.squeeze(1)
        # 256 samples per mel frame — presumably the vocoder hop size; TODO confirm
        wav_lengths = mel_lengths * 256
        infer_secs = mel_infer_secs + vocoder_infer_secs

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for idx, (waveform, n_samples) in enumerate(zip(wavs, wav_lengths)):
        output_filename = output_dir.joinpath(f"output_{idx + 1}.wav")
        audio = waveform[:n_samples]
        print(f"Writing audio to {output_filename}")
        sf.write(output_filename, audio, 22050, "PCM_24")

    wav_secs = wav_lengths.sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    if mel_infer_secs is not None:
        mel_rtf = mel_infer_secs / wav_secs
        print(f"Matcha RTF: {mel_rtf}")
    if vocoder_infer_secs is not None:
        vocoder_rtf = vocoder_infer_secs / wav_secs
        print(f"Vocoder RTF: {vocoder_rtf}")
    print(f"Overall RTF: {rtf}")
def write_mels(model, inputs, output_dir):
    """Run ONNX inference and dump each mel spectrogram as a .png plot plus a numpy file."""
    start = perf_counter()
    mels, mel_lengths = model.run(None, inputs)
    infer_secs = perf_counter() - start

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for idx, mel in enumerate(mels):
        output_stem = output_dir.joinpath(f"output_{idx + 1}")
        plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
        np.save(output_stem.with_suffix(".numpy"), mel)

    # 256 samples per mel frame at 22.05 kHz — TODO confirm against vocoder config
    wav_secs = (mel_lengths * 256).sum() / 22050
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    rtf = infer_secs / wav_secs
    print(f"RTF: {rtf}")
def main():
    """CLI entry point: synthesize speech (or mels) from text with an ONNX Matcha-TTS model."""
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    parser.add_argument(
        "model",
        type=str,
        help="ONNX model to use",
    )
    parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.667,
        help="Variance of the x0 noise (default: 0.667)",
    )
    parser.add_argument(
        "--speaking-rate",
        type=float,
        default=1.0,
        help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
    )
    # Bug fix: help text previously said "Use CPU for inference", but this flag enables GPU.
    parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: CPU)")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )
    args = parser.parse_args()
    args = validate_args(args)

    if args.gpu:
        # Bug fix: ONNX Runtime has no "GPUExecutionProvider"; the CUDA provider
        # is named "CUDAExecutionProvider".
        providers = ["CUDAExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    model = ort.InferenceSession(args.model, providers=providers)
    model_inputs = model.get_inputs()
    model_outputs = list(model.get_outputs())

    # Gather input text, one utterance per line.
    if args.text:
        text_lines = args.text.splitlines()
    else:
        with open(args.file, encoding="utf-8") as file:
            text_lines = file.read().splitlines()

    processed_lines = [process_text(0, line, "cpu") for line in text_lines]
    x = [line["x"].squeeze() for line in processed_lines]
    # Pad to a rectangular batch.
    x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
    x = x.detach().cpu().numpy()
    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
    inputs = {
        "x": x,
        "x_lengths": x_lengths,
        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
    }
    # Multi-speaker graphs expose a fourth input ("spks").
    is_multi_speaker = len(model_inputs) == 4
    if is_multi_speaker:
        if args.spk is None:
            args.spk = 0
            warn = "[!] Speaker ID not provided! Using speaker ID 0"
            warnings.warn(warn, UserWarning)
        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)

    has_vocoder_embedded = model_outputs[0].name == "wav"
    if has_vocoder_embedded:
        write_wavs(model, inputs, args.output_dir)
    elif args.vocoder:
        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
    else:
        warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
        warnings.warn(warn, UserWarning)
        write_mels(model, inputs, args.output_dir)


if __name__ == "__main__":
    main()
third_party/Matcha-TTS/matcha/text/__init__.py
0 → 100644
View file @
0112b0f0
""" from https://github.com/keithito/tacotron """
from
matcha.text
import
cleaners
from
matcha.text.symbols
import
symbols
# Bidirectional lookup tables between symbols and their integer IDs.
_symbol_to_id = {symbol: index for index, symbol in enumerate(symbols)}
_id_to_symbol = dict(enumerate(symbols))
# pylint: disable=unnecessary-comprehension
class UnknownCleanerException(Exception):
    """Raised when a requested cleaner name does not exist in matcha.text.cleaners."""
def text_to_sequence(text, cleaner_names):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through

    Returns:
        Tuple of (list of symbol IDs, cleaned text).
    """
    clean_text = _clean_text(text, cleaner_names)
    # A KeyError here means the cleaners produced a symbol outside the model alphabet.
    sequence = [_symbol_to_id[symbol] for symbol in clean_text]
    return sequence, clean_text
def cleaned_text_to_sequence(cleaned_text):
    """Converts an already-cleaned string to a list of symbol IDs.

    Args:
        cleaned_text: string to convert to a sequence

    Returns:
        List of integers corresponding to the symbols in the text
    """
    ids = []
    for symbol in cleaned_text:
        ids.append(_symbol_to_id[symbol])
    return ids
def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    return "".join(_id_to_symbol[symbol_id] for symbol_id in sequence)
def _clean_text(text, cleaner_names):
    """Run *text* through each named cleaner in order.

    :raises UnknownCleanerException: if a name is not defined in ``cleaners``.
    """
    for name in cleaner_names:
        # Bug fix: without the `None` default, getattr raised AttributeError for an
        # unknown name, so the UnknownCleanerException branch below was unreachable.
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise UnknownCleanerException(f"Unknown cleaner: {name}")
        text = cleaner(text)
    return text
third_party/Matcha-TTS/matcha/text/cleaners.py
0 → 100644
View file @
0112b0f0
""" from https://github.com/keithito/tacotron
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
"""
import
logging
import
re
import
phonemizer
from
unidecode
import
unidecode
# Silence the phonemizer package: only CRITICAL messages get through.
critical_logger = logging.getLogger("phonemizer")
critical_logger.setLevel(logging.CRITICAL)

# Initializing the phonemizer once at module import avoids re-creating the
# espeak backend on every call; less flexible, but much faster.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language="en-us",
    preserve_punctuation=True,
    with_stress=True,
    language_switch="remove-flags",
    logger=critical_logger,
)
# Runs of whitespace (collapsed to a single space by collapse_whitespace):
_whitespace_re = re.compile(r"\s+")

# Any bracket character: (), [], {}
_brackets_re = re.compile(r"[\[\]\(\)\{\}]")

# (compiled pattern, replacement) pairs expanding common English abbreviations.
_abbreviations = [
    (re.compile(f"\\b{abbreviation}\\.", re.IGNORECASE), expansion)
    for abbreviation, expansion in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]
def expand_abbreviations(text):
    """Replace each abbreviation in *text* with its spelled-out form."""
    for pattern, replacement in _abbreviations:
        text = pattern.sub(replacement, text)
    return text
def lowercase(text):
    """Return *text* lowercased."""
    return text.lower()
def remove_brackets(text):
    """Strip all (), [] and {} characters from *text*."""
    return _brackets_re.sub("", text)
def collapse_whitespace(text):
    """Replace each run of whitespace in *text* with a single space."""
    return _whitespace_re.sub(" ", text)
def convert_to_ascii(text):
    """Transliterate *text* to ASCII via Unidecode."""
    return unidecode(text)
def basic_cleaners(text):
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    return collapse_whitespace(lowercase(text))
def transliteration_cleaners(text):
    """Pipeline for non-English text that transliterates to ASCII."""
    return collapse_whitespace(lowercase(convert_to_ascii(text)))
def english_cleaners2(text):
    """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
    text = expand_abbreviations(lowercase(convert_to_ascii(text)))
    phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
    # espeak occasionally leaves brackets in its output; strip them here.
    phonemes = collapse_whitespace(remove_brackets(phonemes))
    return phonemes
def ipa_simplifier(text):
    """Collapse near-equivalent IPA symbols into a canonical form, then normalize whitespace."""
    substitutions = (
        ("ɐ", "ə"),
        ("ˈə", "ə"),
        ("ʤ", "dʒ"),
        ("ʧ", "tʃ"),
        ("ᵻ", "ɪ"),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return collapse_whitespace(text)
# I am removing this due to incompatibility with several version of python
# However, if you want to use it, you can uncomment it
# and install piper-phonemize with the following command:
# pip install piper-phonemize
# import piper_phonemize
# def english_cleaners_piper(text):
# """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
# text = convert_to_ascii(text)
# text = lowercase(text)
# text = expand_abbreviations(text)
# phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
# phonemes = collapse_whitespace(phonemes)
# return phonemes
third_party/Matcha-TTS/matcha/text/numbers.py
0 → 100644
View file @
0112b0f0
""" from https://github.com/keithito/tacotron """
import
re
import
inflect
# Shared inflect engine used for spelling out numbers.
_inflect = inflect.engine()

# Patterns for the number-normalization passes, in the order normalize_numbers applies them.
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
    """re.sub callback: drop thousands separators ("1,234" -> "1234")."""
    return m.group(1).replace(",", "")
def _expand_decimal_point(m):
    """re.sub callback: read a decimal point aloud ("3.14" -> "3 point 14")."""
    return m.group(1).replace(".", " point ")
def _expand_dollars(m):
    """re.sub callback: spell out a dollar amount ("$2.50" -> "2 dollars, 50 cents")."""
    amount = m.group(1)
    parts = amount.split(".")
    if len(parts) > 2:
        # Unexpected shape such as "1.2.3" — read it verbatim.
        return amount + " dollars"
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    dollar_unit = "dollar" if dollars == 1 else "dollars"
    cent_unit = "cent" if cents == 1 else "cents"
    if dollars and cents:
        return f"{dollars} {dollar_unit}, {cents} {cent_unit}"
    if dollars:
        return f"{dollars} {dollar_unit}"
    if cents:
        return f"{cents} {cent_unit}"
    return "zero dollars"
def _expand_ordinal(m):
    """re.sub callback: spell out an ordinal ("2nd" -> "second")."""
    return _inflect.number_to_words(m.group(0))
def _expand_number(m):
    """re.sub callback: spell out a cardinal, reading 1001-2999 year-style."""
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + " hundred"
    # e.g. 1984 -> "nineteen eighty four"
    return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
def normalize_numbers(text):
    """Expand commas, currency, decimals, ordinals and cardinals in *text* to words."""
    passes = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, substitution in passes:
        text = pattern.sub(substitution, text)
    return text
third_party/Matcha-TTS/matcha/text/symbols.py
0 → 100644
View file @
0112b0f0
""" from https://github.com/keithito/tacotron
Defines the set of symbols used in text input to the model.
"""
# Building blocks of the model alphabet.
_pad = "_"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = (
    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
)

# Export all symbols: pad first, then punctuation, letters and IPA characters.
symbols = [_pad, *_punctuation, *_letters, *_letters_ipa]

# ID of the space symbol.
SPACE_ID = symbols.index(" ")
third_party/Matcha-TTS/matcha/train.py
0 → 100644
View file @
0112b0f0
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
hydra
import
lightning
as
L
import
rootutils
from
lightning
import
Callback
,
LightningDataModule
,
LightningModule
,
Trainer
from
lightning.pytorch.loggers
import
Logger
from
omegaconf
import
DictConfig
from
matcha
import
utils
rootutils
.
setup_root
(
__file__
,
indicator
=
".project-root"
,
pythonpath
=
True
)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #
log
=
utils
.
get_pylogger
(
__name__
)
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # Seed pytorch / numpy / python.random if requested.
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")  # pylint: disable=protected-access
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")  # pylint: disable=protected-access
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")  # pylint: disable=protected-access
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # Merge the two metric dicts; test values win on key collisions.
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # Apply extra utilities (e.g. ask for tags if none are provided, print the cfg tree).
    utils.extras(cfg)

    # Train the model.
    metric_dict, _ = train(cfg)

    # Safely retrieve the metric value for hydra-based hyperparameter optimization.
    metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))

    # Return the optimized metric.
    return metric_value


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
third_party/Matcha-TTS/matcha/utils/__init__.py
0 → 100644
View file @
0112b0f0
from
matcha.utils.instantiators
import
instantiate_callbacks
,
instantiate_loggers
from
matcha.utils.logging_utils
import
log_hyperparameters
from
matcha.utils.pylogger
import
get_pylogger
from
matcha.utils.rich_utils
import
enforce_tags
,
print_config_tree
from
matcha.utils.utils
import
extras
,
get_metric_value
,
task_wrapper
third_party/Matcha-TTS/matcha/utils/audio.py
0 → 100644
View file @
0112b0f0
import
numpy
as
np
import
torch
import
torch.utils.data
from
librosa.filters
import
mel
as
librosa_mel_fn
from
scipy.io.wavfile
import
read
# Full-scale magnitude of 16-bit PCM audio.
MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    """Read a wav file and return ``(data, sampling_rate)``."""
    sampling_rate, data = read(full_path)
    return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress *x*: ``log(clip(x, clip_val) * C)`` (numpy version)."""
    clipped = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(clipped * C)
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression: ``exp(x) / C`` (numpy version)."""
    return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress *x*: ``log(clamp(x, min=clip_val) * C)`` (torch version)."""
    return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
    """Invert dynamic_range_compression_torch: ``exp(x) / C``."""
    return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
    """Apply log dynamic-range compression to spectrogram magnitudes."""
    return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
    """Invert spectral_normalize_torch."""
    return dynamic_range_decompression_torch(magnitudes)
# Per-(fmax, device) mel filterbanks and per-device Hann windows, cached across calls.
mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of waveform batch *y* via STFT.

    Assumes *y* is roughly in [-1, 1]; out-of-range extremes are only reported, not clipped.
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
    cache_key = f"{str(fmax)}_{str(y.device)}"
    if cache_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[cache_key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    # Center-pad so frame count matches the hop grid.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # Magnitude with a small epsilon for numerical stability.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
    spec = torch.matmul(mel_basis[cache_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
third_party/Matcha-TTS/matcha/utils/data/__init__.py
0 → 100644
View file @
0112b0f0
third_party/Matcha-TTS/matcha/utils/data/hificaptain.py
0 → 100644
View file @
0112b0f0
#!/usr/bin/env python
import
argparse
import
os
import
sys
import
tempfile
from
pathlib
import
Path
import
torchaudio
from
torch.hub
import
download_url_to_file
from
tqdm
import
tqdm
from
matcha.utils.data.utils
import
_extract_zip
URLS
=
{
"en-US"
:
{
"female"
:
"https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_F.zip"
,
"male"
:
"https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_M.zip"
,
},
"ja-JP"
:
{
"female"
:
"https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_F.zip"
,
"male"
:
"https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_M.zip"
,
},
}
INFO_PAGE
=
"https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/"
# NOTE: NICT describe Hi-Fi-CAPTAIN as "open-sourced" on their website,
# but the data is actually distributed under a non-commercial licence,
# so it is not open source in the usual sense.
LICENCE
=
"CC BY-NC-SA 4.0"
# I'd normally put the citation here. It's on their website.
# Boo to non-open-source stuff.
def get_args():
    """Parse CLI options for downloading and preparing Hi-Fi-CAPTAIN."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files"
    )
    parser.add_argument(
        "-r",
        "--skip-resampling",
        action="store_true",
        default=False,
        help="Skip resampling the data (from 48 to 22.05)",
    )
    parser.add_argument(
        "-l",
        "--language",
        type=str,
        choices=["en-US", "ja-JP"],
        default="en-US",
        help="The language to download",
    )
    parser.add_argument(
        "-g",
        "--gender",
        type=str,
        choices=["male", "female"],
        default="female",
        help="The gender of the speaker to download",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default="data",
        help="Place to store the converted data. Top-level only, the subdirectory will be created",
    )
    return parser.parse_args()
def process_text(infile, outpath: Path):
    """Convert a Hi-Fi-CAPTAIN transcript file into ``<wav path>|<text>`` lines.

    ``*dev.txt`` feeds valid.txt, ``*eval.txt`` feeds test.txt, anything else
    feeds train.txt; appends when the destination already exists so multiple
    source files accumulate into one split.
    """
    if infile.endswith("dev.txt"):
        destination = outpath / "valid.txt"
    elif infile.endswith("eval.txt"):
        destination = outpath / "test.txt"
    else:
        destination = outpath / "train.txt"
    mode = "a" if destination.exists() else "w"
    with (
        open(infile, encoding="utf-8") as source,
        open(destination, mode, encoding="utf-8") as sink,
    ):
        for raw_line in source.readlines():
            utterance_id, transcript = raw_line.strip().split(" ", maxsplit=1)
            wav_path = str(outpath / f"{utterance_id}.wav")
            sink.write(f"{wav_path}|{transcript}\n")
def process_files(zipfile, outpath, resample=True):
    """Extract a Hi-Fi-CAPTAIN zip, converting transcripts and (optionally resampled) wavs."""
    with tempfile.TemporaryDirectory() as workdir:
        for member in tqdm(_extract_zip(zipfile, workdir)):
            if not member.startswith(workdir):
                member = os.path.join(workdir, member)
            if member.endswith(".txt"):
                process_text(member, outpath)
            elif member.endswith(".wav"):
                basename = member.rsplit("/", maxsplit=1)[-1]
                target = str(outpath / basename)
                audio, sample_rate = torchaudio.load(member)
                if resample:
                    audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=22050)
                torchaudio.save(target, audio, 22050)
def main():
    """Download the selected Hi-Fi-CAPTAIN voice and prepare it under the output directory."""
    args = get_args()

    save_dir = None
    if args.save_dir:
        save_dir = Path(args.save_dir)
        if not save_dir.is_dir():
            save_dir.mkdir()

    if not args.output_dir:
        print("output directory not specified, exiting")
        sys.exit(1)

    URL = URLS[args.language][args.gender]
    dirname = f"hi-fi_{args.language}_{args.gender}"

    outbasepath = Path(args.output_dir)
    if not outbasepath.is_dir():
        outbasepath.mkdir()
    outpath = outbasepath / dirname
    if not outpath.is_dir():
        outpath.mkdir()

    resample = not args.skip_resampling

    if save_dir:
        # Keep the archive around for reuse between runs.
        zipname = URL.rsplit("/", maxsplit=1)[-1]
        zipfile = save_dir / zipname
        if not zipfile.exists():
            download_url_to_file(URL, zipfile, progress=True)
        process_files(zipfile, outpath, resample)
    else:
        # No cache directory: stream through a throwaway temp file.
        with tempfile.NamedTemporaryFile(suffix=".zip", delete=True) as zf:
            download_url_to_file(URL, zf.name, progress=True)
            process_files(zf.name, outpath, resample)


if __name__ == "__main__":
    main()
third_party/Matcha-TTS/matcha/utils/data/ljspeech.py
0 → 100644
View file @
0112b0f0
#!/usr/bin/env python
import
argparse
import
random
import
tempfile
from
pathlib
import
Path
from
torch.hub
import
download_url_to_file
from
matcha.utils.data.utils
import
_extract_tar
URL
=
"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
INFO_PAGE
=
"https://keithito.com/LJ-Speech-Dataset/"
LICENCE
=
"Public domain (LibriVox copyright disclaimer)"
CITATION
=
"""
@misc{ljspeech17,
author = {Keith Ito and Linda Johnson},
title = {The LJ Speech Dataset},
howpublished = {
\\
url{https://keithito.com/LJ-Speech-Dataset/}},
year = 2017
}
"""
def decision():
    """Coin flip that routes ~98% of samples to the training split."""
    train_fraction = 0.98
    return random.random() < train_fraction
def get_args():
    """Parse CLI options for the LJSpeech download/convert script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-s",
        "--save-dir",
        type=str,
        default=None,
        help="Place to store the downloaded zip files",
    )
    arg_parser.add_argument(
        "output_dir",
        type=str,
        nargs="?",
        default="data",
        help="Place to store the converted data (subdirectory LJSpeech-1.1 will be created)",
    )
    return arg_parser.parse_args()
def process_csv(ljpath: Path):
    """Split LJSpeech's metadata.csv into train.txt/val.txt filelists.

    Each output line is "<wav path>|<transcript>"; roughly 98% of rows go to
    train.txt and the rest to val.txt (see decision()).

    Args:
        ljpath: directory containing metadata.csv directly, or its parent
            containing an extracted LJSpeech-1.1/ subdirectory.

    Raises:
        FileNotFoundError: if metadata.csv is found in neither location.
    """
    if (ljpath / "metadata.csv").exists():
        basepath = ljpath
    elif (ljpath / "LJSpeech-1.1" / "metadata.csv").exists():
        basepath = ljpath / "LJSpeech-1.1"
    else:
        # Bug fix: previously basepath was left unbound here and the function
        # crashed later with an unrelated NameError; fail early and clearly.
        raise FileNotFoundError(f"metadata.csv not found under {ljpath}")

    csvpath = basepath / "metadata.csv"
    wavpath = basepath / "wavs"
    with (
        open(csvpath, encoding="utf-8") as csvf,
        open(basepath / "train.txt", "w", encoding="utf-8") as tf,
        open(basepath / "val.txt", "w", encoding="utf-8") as vf,
    ):
        # Iterate the file lazily instead of materialising it with readlines().
        for line in csvf:
            parts = line.strip().split("|")
            wavfile = str(wavpath / f"{parts[0]}.wav")
            if decision():
                tf.write(f"{wavfile}|{parts[1]}\n")
            else:
                vf.write(f"{wavfile}|{parts[1]}\n")
def main():
    """CLI entry point: download LJSpeech, extract it, and write train/val filelists."""
    args = get_args()
    save_dir = None
    if args.save_dir:
        save_dir = Path(args.save_dir)
        if not save_dir.is_dir():
            save_dir.mkdir()
    outpath = Path(args.output_dir)
    if not outpath.is_dir():
        outpath.mkdir()
    if save_dir:
        # Cache the tarball in save_dir and reuse it on later runs.
        tarname = URL.rsplit("/", maxsplit=1)[-1]
        # NOTE(review): this local would shadow a `tarfile` module import if one
        # existed at the top of this file — verify.
        tarfile = save_dir / tarname
        if not tarfile.exists():
            download_url_to_file(URL, str(tarfile), progress=True)
        _extract_tar(tarfile, outpath)
        process_csv(outpath)
    else:
        # No cache dir: stream the download into a temp file and extract from there.
        with tempfile.NamedTemporaryFile(suffix=".tar.bz2", delete=True) as zf:
            download_url_to_file(URL, zf.name, progress=True)
            _extract_tar(zf.name, outpath)
            process_csv(outpath)


if __name__ == "__main__":
    main()
third_party/Matcha-TTS/matcha/utils/data/utils.py
0 → 100644
View file @
0112b0f0
# taken from https://github.com/pytorch/audio/blob/main/src/torchaudio/datasets/utils.py
# Copyright (c) 2017 Facebook Inc. (Soumith Chintala)
# Licence: BSD 2-Clause
# pylint: disable=C0123
import
logging
import
os
import
tarfile
import
zipfile
from
pathlib
import
Path
from
typing
import
Any
,
List
,
Optional
,
Union
_LG
=
logging
.
getLogger
(
__name__
)
def
_extract_tar
(
from_path
:
Union
[
str
,
Path
],
to_path
:
Optional
[
str
]
=
None
,
overwrite
:
bool
=
False
)
->
List
[
str
]:
if
type
(
from_path
)
is
Path
:
from_path
=
str
(
Path
)
if
to_path
is
None
:
to_path
=
os
.
path
.
dirname
(
from_path
)
with
tarfile
.
open
(
from_path
,
"r"
)
as
tar
:
files
=
[]
for
file_
in
tar
:
# type: Any
file_path
=
os
.
path
.
join
(
to_path
,
file_
.
name
)
if
file_
.
isfile
():
files
.
append
(
file_path
)
if
os
.
path
.
exists
(
file_path
):
_LG
.
info
(
"%s already extracted."
,
file_path
)
if
not
overwrite
:
continue
tar
.
extract
(
file_
,
to_path
)
return
files
def
_extract_zip
(
from_path
:
Union
[
str
,
Path
],
to_path
:
Optional
[
str
]
=
None
,
overwrite
:
bool
=
False
)
->
List
[
str
]:
if
type
(
from_path
)
is
Path
:
from_path
=
str
(
Path
)
if
to_path
is
None
:
to_path
=
os
.
path
.
dirname
(
from_path
)
with
zipfile
.
ZipFile
(
from_path
,
"r"
)
as
zfile
:
files
=
zfile
.
namelist
()
for
file_
in
files
:
file_path
=
os
.
path
.
join
(
to_path
,
file_
)
if
os
.
path
.
exists
(
file_path
):
_LG
.
info
(
"%s already extracted."
,
file_path
)
if
not
overwrite
:
continue
zfile
.
extract
(
file_
,
to_path
)
return
files
third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py
0 → 100644
View file @
0112b0f0
r
"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import
argparse
import
json
import
os
import
sys
from
pathlib
import
Path
import
rootutils
import
torch
from
hydra
import
compose
,
initialize
from
omegaconf
import
open_dict
from
tqdm.auto
import
tqdm
from
matcha.data.text_mel_datamodule
import
TextMelDataModule
from
matcha.utils.logging_utils
import
pylogger
log
=
pylogger
.
get_pylogger
(
__name__
)
def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
    """Compute the global mel-spectrogram mean and std over a dataloader.

    Args:
        data_loader (torch.utils.data.DataLoader): yields batches containing
            "y" (mel spectrograms) and "y_lengths" (valid frame counts).
        out_channels (int): number of mel channels.

    Returns:
        dict with "mel_mean" and "mel_std" as python floats.
    """
    running_sum = 0
    running_sq_sum = 0
    frame_count = 0

    for batch in tqdm(data_loader, leave=False):
        mels, mel_lengths = batch["y"], batch["y_lengths"]
        frame_count += torch.sum(mel_lengths)
        running_sum += torch.sum(mels)
        running_sq_sum += torch.sum(torch.pow(mels, 2))

    denom = frame_count * out_channels
    mean = running_sum / denom
    # std = sqrt(E[x^2] - E[x]^2)
    std = torch.sqrt(running_sq_sum / denom - torch.pow(mean, 2))
    return {"mel_mean": mean.item(), "mel_std": std.item()}
def main():
    """Compute mel mean/std for a dataset config and dump them to <config>.json."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        # argparse parses a string default through type=int, so this becomes 256.
        default="256",
        help="Can have increased batch size for faster computation",
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    # Output lands next to the CWD, named after the config (e.g. vctk.json).
    output_file = Path(args.input_config).with_suffix(".json")

    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        # Strip hydra bookkeeping and force the knobs needed for a plain stats pass.
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["data_statistics"] = None  # unset so the datamodule does not normalise with old stats
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        # Filelist paths in the config are relative to the project root.
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
        cfg["load_durations"] = False

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)

    with open(output_file, "w", encoding="utf-8") as dumpfile:
        json.dump(params, dumpfile)


if __name__ == "__main__":
    main()
third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py
0 → 100644
View file @
0112b0f0
r
"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import
argparse
import
json
import
os
import
sys
from
pathlib
import
Path
import
lightning
import
numpy
as
np
import
rootutils
import
torch
from
hydra
import
compose
,
initialize
from
omegaconf
import
open_dict
from
torch
import
nn
from
tqdm.auto
import
tqdm
from
matcha.cli
import
get_device
from
matcha.data.text_mel_datamodule
import
TextMelDataModule
from
matcha.models.matcha_tts
import
MatchaTTS
from
matcha.utils.logging_utils
import
pylogger
from
matcha.utils.utils
import
get_phoneme_durations
log
=
pylogger
.
get_pylogger
(
__name__
)
def save_durations_to_folder(
    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
    """Write one utterance's phoneme durations as .npy plus a .json breakdown."""
    # Per-phoneme frame counts from the alignment map, trimmed to the true text length.
    frame_counts = attn.squeeze().sum(1)[:x_length].numpy()
    per_phoneme = get_phoneme_durations(frame_counts, text)

    npy_target = output_folder / Path(filepath).name.replace(".wav", ".npy")
    json_target = npy_target.with_suffix(".json")
    with open(json_target, "w", encoding="utf-8") as handle:
        json.dump(per_phoneme, handle, indent=4, ensure_ascii=False)
    np.save(npy_target, frame_counts)
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
    """Generate durations from the model for each datapoint and save it in a folder

    Args:
        data_loader (torch.utils.data.DataLoader): Dataloader
        model (nn.Module): MatchaTTS model
        device (torch.device): GPU or CPU
        output_folder: directory that save_durations_to_folder writes .npy/.json into
    """
    for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]
        x = x.to(device)
        y = y.to(device)
        x_lengths = x_lengths.to(device)
        y_lengths = y_lengths.to(device)
        # spks may be None (e.g. single-speaker data); only move it if present.
        spks = spks.to(device) if spks is not None else None

        # Only the alignment map (4th output) is used here; the other model
        # outputs are discarded.
        _, _, _, attn = model(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
        )
        attn = attn.cpu()
        for i in range(attn.shape[0]):
            # Persist each item's durations keyed by its source audio filename.
            save_durations_to_folder(
                attn[i],
                x_lengths[i].item(),
                y_lengths[i].item(),
                batch["filepaths"][i],
                output_folder,
                batch["x_texts"][i],
            )
def main():
    """CLI entry point: load a trained MatchaTTS checkpoint and dump per-utterance durations."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="ljspeech.yaml",
        help="The name of the yaml config file under configs/data",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        # argparse parses a string default through type=int, so this becomes 32.
        default="32",
        help="Can have increased batch size for faster computation",
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    parser.add_argument(
        "-c",
        "--checkpoint_path",
        type=str,
        required=True,
        help="Path to the checkpoint file to load the model from",
    )
    parser.add_argument(
        "-o",
        "--output-folder",
        type=str,
        default=None,
        help="Output folder to save the data statistics",
    )
    parser.add_argument(
        "--cpu", action="store_true", help="Use CPU for inference, not recommended (default: use GPU if available)"
    )
    args = parser.parse_args()

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        # Strip hydra bookkeeping and resolve filelist paths relative to the project root.
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
        cfg["load_durations"] = False

    if args.output_folder is not None:
        output_folder = Path(args.output_folder)
    else:
        # Default: a "durations" folder next to the training filelist.
        output_folder = Path(cfg["train_filelist_path"]).parent / "durations"

    print(f"Output folder set to: {output_folder}")

    if os.path.exists(output_folder) and not args.force:
        print("Folder already exists. Use -f to force overwrite")
        sys.exit(1)

    output_folder.mkdir(parents=True, exist_ok=True)

    print(f"Preprocessing: {cfg['name']} from training filelist: {cfg['train_filelist_path']}")
    print("Loading model...")
    device = get_device(args)
    model = MatchaTTS.load_from_checkpoint(args.checkpoint_path, map_location=device)

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()

    # Each split may be absent; Lightning raises MisconfigurationException then,
    # which is treated as "skip this split".
    try:
        print("Computing stats for training set if exists...")
        train_dataloader = text_mel_datamodule.train_dataloader()
        compute_durations(train_dataloader, model, device, output_folder)
    except lightning.fabric.utilities.exceptions.MisconfigurationException:
        print("No training set found")

    try:
        print("Computing stats for validation set if exists...")
        val_dataloader = text_mel_datamodule.val_dataloader()
        compute_durations(val_dataloader, model, device, output_folder)
    except lightning.fabric.utilities.exceptions.MisconfigurationException:
        print("No validation set found")

    try:
        print("Computing stats for test set if exists...")
        test_dataloader = text_mel_datamodule.test_dataloader()
        compute_durations(test_dataloader, model, device, output_folder)
    except lightning.fabric.utilities.exceptions.MisconfigurationException:
        print("No test set found")

    print(f"[+] Done! Data statistics saved to: {output_folder}")


if __name__ == "__main__":
    # Helps with generating durations for the dataset to train other architectures
    # that cannot learn to align due to limited size of dataset
    # Example usage:
    # python python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
    # This will create a folder in data/processed_data/durations/ljspeech with the durations
    main()
third_party/Matcha-TTS/matcha/utils/instantiators.py
0 → 100644
View file @
0112b0f0
from
typing
import
List
import
hydra
from
lightning
import
Callback
from
lightning.pytorch.loggers
import
Logger
from
omegaconf
import
DictConfig
from
matcha.utils
import
pylogger
log
=
pylogger
.
get_pylogger
(
__name__
)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    """
    instantiated: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return instantiated

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    # Only entries that are themselves configs with a _target_ are instantiated.
    for cb_conf in callbacks_cfg.values():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")  # pylint: disable=protected-access
            instantiated.append(hydra.utils.instantiate(cb_conf))

    return instantiated
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    """
    instantiated: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return instantiated

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    # Only entries that are themselves configs with a _target_ are instantiated.
    for lg_conf in logger_cfg.values():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")  # pylint: disable=protected-access
            instantiated.append(hydra.utils.instantiate(lg_conf))

    return instantiated
third_party/Matcha-TTS/matcha/utils/logging_utils.py
0 → 100644
View file @
0112b0f0
from
typing
import
Any
,
Dict
from
lightning.pytorch.utilities
import
rank_zero_only
from
omegaconf
import
OmegaConf
from
matcha.utils
import
pylogger
log
=
pylogger
.
get_pylogger
(
__name__
)
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    hparams = {}

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    # Optional sections: cfg.get() yields None when the key is absent.
    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
third_party/Matcha-TTS/matcha/utils/model.py
0 → 100644
View file @
0112b0f0
""" from https://github.com/jaywalnut310/glow-tts """
import
numpy
as
np
import
torch
def sequence_mask(length, max_length=None):
    """Boolean mask [B, T] that is True at positions inside each sequence length."""
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round length up to a multiple of 2**num_downsamplings_in_unet.

    Returns a python int normally; during ONNX export the tensor is returned
    as-is so the graph stays traceable.
    """
    factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
    padded = (length / factor).ceil() * factor
    if torch.onnx.is_in_onnx_export():
        return padded
    return padded.int().item()
def convert_pad_shape(pad_shape):
    """Flatten a reversed [[l, r], ...] pad spec into the flat list F.pad expects."""
    flat = []
    for pair in reversed(pad_shape):
        flat.extend(pair)
    return flat
def generate_path(duration, mask):
    """Build a hard monotonic alignment path [b, t_x, t_y] from integer durations.

    Text position x claims the run of frames
    [cumsum(duration)[x-1], cumsum(duration)[x]); the result is 1 inside that
    run and 0 elsewhere, then masked by `mask`.
    """
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    # NOTE(review): this zeros allocation is immediately overwritten below and
    # appears to be dead.
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    # Mask of frames covered up to and including each text position...
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # ...minus the frames covered by the previous position (shift by one along
    # t_x), leaving exactly the frames assigned to each position.
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path
def duration_loss(logw, logw_, lengths):
    """Mean-squared error between predicted and target log-durations,
    normalised by the total sequence length."""
    squared_error = (logw - logw_) ** 2
    return torch.sum(squared_error) / torch.sum(lengths)
def normalize(data, mu, std):
    """Normalise `data` to (data - mu) / std.

    mu/std may be python scalars, lists, numpy arrays, or torch tensors;
    non-scalar stats are moved to data's device and given a trailing singleton
    dim so they broadcast across the last (time) axis.
    """

    def _as_stat(stat):
        # Scalars pass through untouched; everything else becomes a column
        # tensor on data's device (deduplicates the two identical chains the
        # original repeated for mu and std).
        if isinstance(stat, (float, int)):
            return stat
        if isinstance(stat, list):
            stat = torch.tensor(stat, dtype=data.dtype, device=data.device)
        elif isinstance(stat, torch.Tensor):
            stat = stat.to(data.device)
        elif isinstance(stat, np.ndarray):
            stat = torch.from_numpy(stat).to(data.device)
        return stat.unsqueeze(-1)

    return (data - _as_stat(mu)) / _as_stat(std)
def denormalize(data, mu, std):
    """Invert normalize(): return data * std + mu.

    Accepts the same stat types as normalize(). Bug fix: int stats previously
    slipped past the `isinstance(mu, float)` check and crashed on
    `.unsqueeze(-1)`; ints are now treated as scalars, consistent with
    normalize().
    """

    def _as_stat(stat):
        # Scalars pass through untouched; everything else becomes a column
        # tensor on data's device.
        if isinstance(stat, (float, int)):
            return stat
        if isinstance(stat, list):
            stat = torch.tensor(stat, dtype=data.dtype, device=data.device)
        elif isinstance(stat, torch.Tensor):
            stat = stat.to(data.device)
        elif isinstance(stat, np.ndarray):
            stat = torch.from_numpy(stat).to(data.device)
        return stat.unsqueeze(-1)

    return data * _as_stat(std) + _as_stat(mu)
third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py
0 → 100644
View file @
0112b0f0
import
numpy
as
np
import
torch
from
matcha.utils.monotonic_align.core
import
maximum_path_c
def maximum_path(value, mask):
    """Cython optimised version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    # The Cython kernel expects contiguous float32 scores and an int32 path buffer.
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()

    # Per-item valid lengths along each axis, read off the (rectangular) mask.
    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    # Fills `path` in place with the best monotonic alignment per batch item.
    maximum_path_c(path, value, t_x_max, t_y_max)
    # Move the result back to the caller's device and dtype.
    return torch.from_numpy(path).to(device=device, dtype=dtype)
third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx
0 → 100644
View file @
0112b0f0
import
numpy
as
np
cimport
cython
cimport
numpy
as
np
from
cython.parallel
import
prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
  # Dynamic programming over a single [t_x, t_y] score matrix: a forward pass
  # accumulates best monotonic-path scores in `value` (in place), then a
  # backward pass writes the chosen path as 1s into `path`.
  cdef int x
  cdef int y
  cdef float v_prev
  cdef float v_cur
  cdef float tmp  # NOTE(review): unused
  cdef int index = t_x - 1

  # Forward pass: value[x, y] becomes the best cumulative score of any
  # monotonic path ending at (x, y). The inner range restricts x to cells
  # from which the end (t_x-1, t_y-1) is still reachable.
  for y in range(t_y):
    for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
      if x == y:
        v_cur = max_neg_val  # can't have stayed on the same x all the way here
      else:
        v_cur = value[x, y-1]  # came from the same text position
      if x == 0:
        if y == 0:
          v_prev = 0.
        else:
          v_prev = max_neg_val
      else:
        v_prev = value[x-1, y-1]  # came from the previous text position
      value[x, y] = max(v_cur, v_prev) + value[x, y]

  # Backward pass: walk from the last frame to the first, marking the path and
  # stepping down a text position whenever the diagonal predecessor scored higher.
  for y in range(t_y - 1, -1, -1):
    path[index, y] = 1
    if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
      index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
  # Batched entry point: run the single-item DP over every batch element.
  # prange distributes batch items across threads with the GIL released.
  cdef int b = values.shape[0]
  cdef int i
  for i in prange(b, nogil=True):
    maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py
0 → 100644
View file @
0112b0f0
# from distutils.core import setup
# from Cython.Build import cythonize
# import numpy
# setup(name='monotonic_align',
# ext_modules=cythonize("core.pyx"),
# include_dirs=[numpy.get_include()])
Prev
1
…
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment