Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
VITA-Audio_pytorch
Commits
39ac40a9
Commit
39ac40a9
authored
Jun 06, 2025
by
chenzk
Browse files
v1.0
parents
Pipeline
#2747
failed with stages
in 0 seconds
Changes
427
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1556 additions
and
0 deletions
+1556
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/text/numbers.py
...GLM-4-Voice/third_party/Matcha-TTS/matcha/text/numbers.py
+71
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/text/symbols.py
...GLM-4-Voice/third_party/Matcha-TTS/matcha/text/symbols.py
+17
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/train.py
..._party/GLM-4-Voice/third_party/Matcha-TTS/matcha/train.py
+122
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/__init__.py
...M-4-Voice/third_party/Matcha-TTS/matcha/utils/__init__.py
+5
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/audio.py
.../GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/audio.py
+82
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/data/__init__.py
...oice/third_party/Matcha-TTS/matcha/utils/data/__init__.py
+0
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/data/hificaptain.py
...e/third_party/Matcha-TTS/matcha/utils/data/hificaptain.py
+148
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/data/ljspeech.py
...oice/third_party/Matcha-TTS/matcha/utils/data/ljspeech.py
+97
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/data/utils.py
...4-Voice/third_party/Matcha-TTS/matcha/utils/data/utils.py
+53
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py
...party/Matcha-TTS/matcha/utils/generate_data_statistics.py
+110
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py
...tcha-TTS/matcha/utils/get_durations_from_trained_model.py
+195
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/instantiators.py
...oice/third_party/Matcha-TTS/matcha/utils/instantiators.py
+56
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/logging_utils.py
...oice/third_party/Matcha-TTS/matcha/utils/logging_utils.py
+53
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/model.py
.../GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/model.py
+90
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py
...party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py
+22
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx
...rd_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx
+47
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py
...rd_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py
+7
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/pylogger.py
...M-4-Voice/third_party/Matcha-TTS/matcha/utils/pylogger.py
+21
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/rich_utils.py
...4-Voice/third_party/Matcha-TTS/matcha/utils/rich_utils.py
+101
-0
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/utils.py
.../GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/utils.py
+259
-0
No files found.
Too many changes to show.
To preserve performance only
427 of 427+
files are displayed.
Plain diff
Email patch
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/text/numbers.py
0 → 100644
View file @
39ac40a9
""" from https://github.com/keithito/tacotron """
import
re
import
inflect
_inflect
=
inflect
.
engine
()
_comma_number_re
=
re
.
compile
(
r
"([0-9][0-9\,]+[0-9])"
)
_decimal_number_re
=
re
.
compile
(
r
"([0-9]+\.[0-9]+)"
)
_pounds_re
=
re
.
compile
(
r
"£([0-9\,]*[0-9]+)"
)
_dollars_re
=
re
.
compile
(
r
"\$([0-9\.\,]*[0-9]+)"
)
_ordinal_re
=
re
.
compile
(
r
"[0-9]+(st|nd|rd|th)"
)
_number_re
=
re
.
compile
(
r
"[0-9]+"
)
def
_remove_commas
(
m
):
return
m
.
group
(
1
).
replace
(
","
,
""
)
def
_expand_decimal_point
(
m
):
return
m
.
group
(
1
).
replace
(
"."
,
" point "
)
def
_expand_dollars
(
m
):
match
=
m
.
group
(
1
)
parts
=
match
.
split
(
"."
)
if
len
(
parts
)
>
2
:
return
match
+
" dollars"
dollars
=
int
(
parts
[
0
])
if
parts
[
0
]
else
0
cents
=
int
(
parts
[
1
])
if
len
(
parts
)
>
1
and
parts
[
1
]
else
0
if
dollars
and
cents
:
dollar_unit
=
"dollar"
if
dollars
==
1
else
"dollars"
cent_unit
=
"cent"
if
cents
==
1
else
"cents"
return
f
"
{
dollars
}
{
dollar_unit
}
,
{
cents
}
{
cent_unit
}
"
elif
dollars
:
dollar_unit
=
"dollar"
if
dollars
==
1
else
"dollars"
return
f
"
{
dollars
}
{
dollar_unit
}
"
elif
cents
:
cent_unit
=
"cent"
if
cents
==
1
else
"cents"
return
f
"
{
cents
}
{
cent_unit
}
"
else
:
return
"zero dollars"
def _expand_ordinal(m):
    """Spell out an ordinal like '2nd' -> 'second' (inflect parses the suffix)."""
    ordinal_text = m.group(0)
    return _inflect.number_to_words(ordinal_text)
def _expand_number(m):
    """Spell out a cardinal number; values in 1001-2999 get a year-style reading."""
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        # e.g. 2005 -> "two thousand five"
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        # e.g. 1900 -> "nineteen hundred"
        return _inflect.number_to_words(num // 100) + " hundred"
    # e.g. 1984 -> "nineteen eighty-four" (grouped in pairs, zero read as "oh")
    return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
def normalize_numbers(text):
    """Expand money amounts, decimals, ordinals and cardinals in *text* to words.

    Order matters: commas are stripped first, currency before decimals, and the
    bare-digit fallback runs last.
    """
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/text/symbols.py
0 → 100644
View file @
39ac40a9
""" from https://github.com/keithito/tacotron
Defines the set of symbols used in text input to the model.
"""
_pad
=
"_"
_punctuation
=
';:,.!?¡¿—…"«»“” '
_letters
=
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa
=
(
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
)
# Export all symbols:
symbols
=
[
_pad
]
+
list
(
_punctuation
)
+
list
(
_letters
)
+
list
(
_letters_ipa
)
# Special symbol ids
SPACE_ID
=
symbols
.
index
(
" "
)
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/train.py
0 → 100644
View file @
39ac40a9
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import rootutils
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from matcha import utils

# Locate the project root (marked by ".project-root") and put it on PYTHONPATH
# so local imports resolve regardless of the working directory.
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

# Module-level logger named after this module.
log = utils.get_pylogger(__name__)
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")  # pylint: disable=protected-access
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")  # pylint: disable=protected-access
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")  # pylint: disable=protected-access
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    # Everything instantiated so far, returned to the caller for introspection
    # and passed to log_hyperparameters below.
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        # ckpt_path (optional) resumes training from a checkpoint.
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        # Prefer the best checkpoint found during training; fall back to the
        # current (in-memory) weights when none was saved.
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/__init__.py
0 → 100644
View file @
39ac40a9
from
matcha.utils.instantiators
import
instantiate_callbacks
,
instantiate_loggers
from
matcha.utils.logging_utils
import
log_hyperparameters
from
matcha.utils.pylogger
import
get_pylogger
from
matcha.utils.rich_utils
import
enforce_tags
,
print_config_tree
from
matcha.utils.utils
import
extras
,
get_metric_value
,
task_wrapper
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/audio.py
0 → 100644
View file @
39ac40a9
import
numpy
as
np
import
torch
import
torch.utils.data
from
librosa.filters
import
mel
as
librosa_mel_fn
from
scipy.io.wavfile
import
read
# Full-scale magnitude of 16-bit PCM audio; used to normalise int16 wavs to [-1, 1].
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
    """Read a wav file; returns (samples, sampling_rate).

    Note scipy's read() yields (rate, data) — the order is swapped here.
    """
    rate_and_data = read(full_path)
    return rate_and_data[1], rate_and_data[0]
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes: log(clip(x, clip_val) * C). NumPy variant."""
    floored = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(floored * C)
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression: exp(x) / C. NumPy variant."""
    expanded = np.exp(x)
    return expanded / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes: log(clamp(x, clip_val) * C). Torch variant."""
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
def dynamic_range_decompression_torch(x, C=1):
    """Invert dynamic_range_compression_torch: exp(x) / C. Torch variant."""
    expanded = torch.exp(x)
    return expanded / C
def spectral_normalize_torch(magnitudes):
    """Apply log dynamic-range compression to spectrogram magnitudes."""
    return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
    """Undo spectral_normalize_torch (exp decompression)."""
    return dynamic_range_decompression_torch(magnitudes)
# Lazy per-(fmax, device) cache of mel filterbank matrices, filled by mel_spectrogram.
mel_basis = {}
# Lazy per-device cache of Hann windows.
# NOTE(review): keyed by device only, not win_size — assumes a single win_size per run; confirm.
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-mel spectrogram of waveform batch `y` via STFT + mel filterbank.

    :param y: waveform tensor; values are expected in [-1, 1] (out-of-range only warns)
    :param n_fft: FFT size
    :param num_mels: number of mel bands
    :param sampling_rate: audio sample rate in Hz
    :param hop_size: STFT hop length
    :param win_size: STFT window length (Hann)
    :param fmin: lowest mel filterbank frequency
    :param fmax: highest mel filterbank frequency (also part of the cache key)
    :param center: passed to torch.stft; input is manually reflect-padded instead
    :return: log-compressed mel spectrogram tensor
    """
    # Sanity warnings only — values outside [-1, 1] are not clipped.
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement,global-variable-not-assigned
    # Build and cache the mel filterbank / Hann window once per (fmax, device).
    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    # Manual reflect padding so frames stay centered even with center=False.
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # Magnitude with a small epsilon for numerical stability of the sqrt/log.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    # Project linear spectrogram onto the cached mel basis, then log-compress.
    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/data/__init__.py
0 → 100644
View file @
39ac40a9
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/data/hificaptain.py
0 → 100644
View file @
39ac40a9
#!/usr/bin/env python
import
argparse
import
os
import
sys
import
tempfile
from
pathlib
import
Path
import
torchaudio
from
torch.hub
import
download_url_to_file
from
tqdm
import
tqdm
from
matcha.utils.data.utils
import
_extract_zip
# Download URLs for each (language, speaker-gender) subset of Hi-Fi-CAPTAIN.
URLS = {
    "en-US": {
        "female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_F.zip",
        "male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_en-US_M.zip",
    },
    "ja-JP": {
        "female": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_F.zip",
        "male": "https://ast-astrec.nict.go.jp/release/hi-fi-captain/hfc_ja-JP_M.zip",
    },
}

INFO_PAGE = "https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/"

# On their website they say "We NICT open-sourced Hi-Fi-CAPTAIN",
# but they use this very-much-not-open-source licence.
# Dunno if this is open washing or stupidity.
LICENCE = "CC BY-NC-SA 4.0"

# I'd normally put the citation here. It's on their website.
# Boo to non-open-source stuff.
def get_args():
    """Parse command-line options for the Hi-Fi-CAPTAIN download script."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files")
    argparser.add_argument(
        "-r",
        "--skip-resampling",
        action="store_true",
        default=False,
        help="Skip resampling the data (from 48 to 22.05)",
    )
    argparser.add_argument(
        "-l", "--language", type=str, choices=["en-US", "ja-JP"], default="en-US", help="The language to download"
    )
    argparser.add_argument(
        "-g",
        "--gender",
        type=str,
        choices=["male", "female"],
        default="female",
        help="The gender of the speaker to download",
    )
    argparser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default="data",
        help="Place to store the converted data. Top-level only, the subdirectory will be created",
    )
    return argparser.parse_args()
def process_text(infile, outpath: Path):
    """Convert one Hi-Fi-CAPTAIN transcript file into a "wav_path|text" filelist.

    Routing: *dev.txt -> valid.txt, *eval.txt -> test.txt, anything else -> train.txt.
    Appends when the destination already exists, so multiple inputs accumulate.

    :param infile: path to the source transcript ("<fileid> <text>" per line)
    :param outpath: directory receiving the filelist (wavs are copied there separately)
    """
    outmode = "w"
    if infile.endswith("dev.txt"):
        outfile = outpath / "valid.txt"
    elif infile.endswith("eval.txt"):
        outfile = outpath / "test.txt"
    else:
        outfile = outpath / "train.txt"
    if outfile.exists():
        outmode = "a"
    with (
        open(infile, encoding="utf-8") as inf,
        open(outfile, outmode, encoding="utf-8") as of,
    ):
        for line in inf:
            line = line.strip()
            if not line:
                # Robustness: a blank/trailing line used to crash the unpack below.
                continue
            fileid, rest = line.split(" ", maxsplit=1)
            # BUGFIX: use a distinct name here — the original rebound `outfile`,
            # shadowing the destination path variable inside the loop.
            wavfile = str(outpath / f"{fileid}.wav")
            of.write(f"{wavfile}|{rest}\n")
def process_files(zipfile, outpath, resample=True):
    """Extract a Hi-Fi-CAPTAIN zip: .txt members go through process_text, .wav
    members are copied into outpath, optionally resampled 48 kHz -> 22.05 kHz.

    NOTE(review): when resample is False the wav is still saved with a 22050
    header even though the samples were not resampled — confirm this is intended.
    """
    with tempfile.TemporaryDirectory() as workdir:
        for member in tqdm(_extract_zip(zipfile, workdir)):
            if not member.startswith(workdir):
                member = os.path.join(workdir, member)
            if member.endswith(".txt"):
                process_text(member, outpath)
            elif member.endswith(".wav"):
                basename = member.rsplit("/", maxsplit=1)[-1]
                target = str(outpath / basename)
                audio, sample_rate = torchaudio.load(member)
                if resample:
                    audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=22050)
                torchaudio.save(target, audio, 22050)
            else:
                continue
def main():
    """Download the selected Hi-Fi-CAPTAIN subset and convert it into a filelist + wavs."""
    args = get_args()
    save_dir = None
    if args.save_dir:
        save_dir = Path(args.save_dir)
        if not save_dir.is_dir():
            save_dir.mkdir()
    if not args.output_dir:
        print("output directory not specified, exiting")
        sys.exit(1)
    URL = URLS[args.language][args.gender]
    # Output subdirectory name, e.g. "hi-fi_en-US_female".
    dirname = f"hi-fi_{args.language}_{args.gender}"
    outbasepath = Path(args.output_dir)
    if not outbasepath.is_dir():
        outbasepath.mkdir()
    outpath = outbasepath / dirname
    if not outpath.is_dir():
        outpath.mkdir()
    resample = True
    if args.skip_resampling:
        resample = False
    if save_dir:
        # Keep the downloaded zip around for reuse; only fetch if missing.
        zipname = URL.rsplit("/", maxsplit=1)[-1]
        zipfile = save_dir / zipname
        if not zipfile.exists():
            download_url_to_file(URL, zipfile, progress=True)
        process_files(zipfile, outpath, resample)
    else:
        # No save dir: stream into a temp file that is deleted afterwards.
        with tempfile.NamedTemporaryFile(suffix=".zip", delete=True) as zf:
            download_url_to_file(URL, zf.name, progress=True)
            process_files(zf.name, outpath, resample)


if __name__ == "__main__":
    main()
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/data/ljspeech.py
0 → 100644
View file @
39ac40a9
#!/usr/bin/env python
import
argparse
import
random
import
tempfile
from
pathlib
import
Path
from
torch.hub
import
download_url_to_file
from
matcha.utils.data.utils
import
_extract_tar
# Official LJSpeech 1.1 distribution tarball.
URL = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"

INFO_PAGE = "https://keithito.com/LJ-Speech-Dataset/"

LICENCE = "Public domain (LibriVox copyright disclaimer)"

CITATION = """
@misc{ljspeech17,
  author = {Keith Ito and Linda Johnson},
  title = {The LJ Speech Dataset},
  howpublished = {\\url{https://keithito.com/LJ-Speech-Dataset/}},
  year = 2017
}
"""
def decision():
    """Randomly route a sample: True (~98%) for train, False (~2%) for validation."""
    draw = random.random()
    return draw < 0.98
def get_args():
    """Parse command-line options for the LJSpeech download script."""
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "-s", "--save-dir", type=str, default=None, help="Place to store the downloaded zip files"
    )
    argparser.add_argument(
        "output_dir",
        type=str,
        nargs="?",
        default="data",
        help="Place to store the converted data (subdirectory LJSpeech-1.1 will be created)",
    )
    return argparser.parse_args()
def process_csv(ljpath: Path):
    """Split LJSpeech's metadata.csv into train.txt / val.txt filelists.

    Each output line is "<wav path>|<transcript>" (field 1 of the csv).
    Roughly 98% of lines go to train.txt, the rest to val.txt (see decision()).

    :param ljpath: directory containing metadata.csv, or its parent
    :raises FileNotFoundError: if metadata.csv cannot be located
    """
    if (ljpath / "metadata.csv").exists():
        basepath = ljpath
    elif (ljpath / "LJSpeech-1.1" / "metadata.csv").exists():
        basepath = ljpath / "LJSpeech-1.1"
    else:
        # BUGFIX: the original fell through with `basepath` unbound (NameError).
        raise FileNotFoundError(f"metadata.csv not found under {ljpath}")
    csvpath = basepath / "metadata.csv"
    wavpath = basepath / "wavs"
    with (
        open(csvpath, encoding="utf-8") as csvf,
        open(basepath / "train.txt", "w", encoding="utf-8") as tf,
        open(basepath / "val.txt", "w", encoding="utf-8") as vf,
    ):
        for line in csvf:
            line = line.strip()
            parts = line.split("|")
            wavfile = str(wavpath / f"{parts[0]}.wav")
            # decision() is random, so the split differs run to run.
            out = tf if decision() else vf
            out.write(f"{wavfile}|{parts[1]}\n")
def main():
    """Download LJSpeech (or reuse a cached tarball), extract it, and build filelists."""
    args = get_args()
    save_dir = None
    if args.save_dir:
        save_dir = Path(args.save_dir)
        if not save_dir.is_dir():
            save_dir.mkdir()
    outpath = Path(args.output_dir)
    if not outpath.is_dir():
        outpath.mkdir()
    if save_dir:
        # Keep the downloaded tarball for reuse; only fetch if missing.
        tarname = URL.rsplit("/", maxsplit=1)[-1]
        tarfile = save_dir / tarname
        if not tarfile.exists():
            download_url_to_file(URL, str(tarfile), progress=True)
        _extract_tar(tarfile, outpath)
        process_csv(outpath)
    else:
        # No save dir: stream into a temp file that is deleted afterwards.
        with tempfile.NamedTemporaryFile(suffix=".tar.bz2", delete=True) as zf:
            download_url_to_file(URL, zf.name, progress=True)
            _extract_tar(zf.name, outpath)
            process_csv(outpath)


if __name__ == "__main__":
    main()
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/data/utils.py
0 → 100644
View file @
39ac40a9
# taken from https://github.com/pytorch/audio/blob/main/src/torchaudio/datasets/utils.py
# Copyright (c) 2017 Facebook Inc. (Soumith Chintala)
# Licence: BSD 2-Clause
# pylint: disable=C0123
import
logging
import
os
import
tarfile
import
zipfile
from
pathlib
import
Path
from
typing
import
Any
,
List
,
Optional
,
Union
_LG
=
logging
.
getLogger
(
__name__
)
def
_extract_tar
(
from_path
:
Union
[
str
,
Path
],
to_path
:
Optional
[
str
]
=
None
,
overwrite
:
bool
=
False
)
->
List
[
str
]:
if
type
(
from_path
)
is
Path
:
from_path
=
str
(
Path
)
if
to_path
is
None
:
to_path
=
os
.
path
.
dirname
(
from_path
)
with
tarfile
.
open
(
from_path
,
"r"
)
as
tar
:
files
=
[]
for
file_
in
tar
:
# type: Any
file_path
=
os
.
path
.
join
(
to_path
,
file_
.
name
)
if
file_
.
isfile
():
files
.
append
(
file_path
)
if
os
.
path
.
exists
(
file_path
):
_LG
.
info
(
"%s already extracted."
,
file_path
)
if
not
overwrite
:
continue
tar
.
extract
(
file_
,
to_path
)
return
files
def
_extract_zip
(
from_path
:
Union
[
str
,
Path
],
to_path
:
Optional
[
str
]
=
None
,
overwrite
:
bool
=
False
)
->
List
[
str
]:
if
type
(
from_path
)
is
Path
:
from_path
=
str
(
Path
)
if
to_path
is
None
:
to_path
=
os
.
path
.
dirname
(
from_path
)
with
zipfile
.
ZipFile
(
from_path
,
"r"
)
as
zfile
:
files
=
zfile
.
namelist
()
for
file_
in
files
:
file_path
=
os
.
path
.
join
(
to_path
,
file_
)
if
os
.
path
.
exists
(
file_path
):
_LG
.
info
(
"%s already extracted."
,
file_path
)
if
not
overwrite
:
continue
zfile
.
extract
(
file_
,
to_path
)
return
files
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py
0 → 100644
View file @
39ac40a9
r
"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import
argparse
import
json
import
os
import
sys
from
pathlib
import
Path
import
rootutils
import
torch
from
hydra
import
compose
,
initialize
from
omegaconf
import
open_dict
from
tqdm.auto
import
tqdm
from
matcha.data.text_mel_datamodule
import
TextMelDataModule
from
matcha.utils.logging_utils
import
pylogger
log
=
pylogger
.
get_pylogger
(
__name__
)
def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
    """Generate data mean and standard deviation helpful in data normalisation

    Args:
        data_loader (torch.utils.data.Dataloader): _description_
        out_channels (int): mel spectrogram channels
    """
    running_sum = 0
    running_sq_sum = 0
    running_len = 0
    for batch in tqdm(data_loader, leave=False):
        spectrograms = batch["y"]
        lengths = batch["y_lengths"]
        running_len += torch.sum(lengths)
        running_sum += torch.sum(spectrograms)
        running_sq_sum += torch.sum(torch.pow(spectrograms, 2))

    # Per-element mean/std over all frames and channels: E[x] and sqrt(E[x^2] - E[x]^2).
    denom = running_len * out_channels
    mean = running_sum / denom
    std = torch.sqrt((running_sq_sum / denom) - torch.pow(mean, 2))
    return {"mel_mean": mean.item(), "mel_std": std.item()}
def main():
    """Compute mel mean/std for a dataset config and dump them to <config>.json."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        # NOTE(review): default given as a string; argparse applies `type` to string
        # defaults, so this becomes int 256 — but passing an int directly is clearer.
        default="256",
        help="Can have increased batch size for faster computation",
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    # Output JSON sits next to the config name, e.g. vctk.yaml -> vctk.json.
    output_file = Path(args.input_config).with_suffix(".json")

    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    # Strip hydra bookkeeping and rewrite relative filelist paths to absolute ones
    # so the datamodule can be instantiated directly from the config dict.
    with open_dict(cfg):
        del cfg["hydra"]
        del cfg["_target_"]
        cfg["data_statistics"] = None
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
        cfg["load_durations"] = False

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)
    with open(output_file, "w", encoding="utf-8") as dumpfile:
        json.dump(params, dumpfile)


if __name__ == "__main__":
    main()
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/get_durations_from_trained_model.py
0 → 100644
View file @
39ac40a9
r
"""
The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it
when needed.
Parameters from hparam.py will be used
"""
import
argparse
import
json
import
os
import
sys
from
pathlib
import
Path
import
lightning
import
numpy
as
np
import
rootutils
import
torch
from
hydra
import
compose
,
initialize
from
omegaconf
import
open_dict
from
torch
import
nn
from
tqdm.auto
import
tqdm
from
matcha.cli
import
get_device
from
matcha.data.text_mel_datamodule
import
TextMelDataModule
from
matcha.models.matcha_tts
import
MatchaTTS
from
matcha.utils.logging_utils
import
pylogger
from
matcha.utils.utils
import
get_phoneme_durations
log
=
pylogger
.
get_pylogger
(
__name__
)
def save_durations_to_folder(
    attn: torch.Tensor, x_length: int, y_length: int, filepath: str, output_folder: Path, text: str
):
    """Persist per-phoneme durations for one utterance: a .npy array plus a
    human-readable .json breakdown, both named after the source wav file."""
    frame_counts = attn.squeeze().sum(1)[:x_length].numpy()
    breakdown = get_phoneme_durations(frame_counts, text)
    target = output_folder / Path(filepath).name.replace(".wav", ".npy")
    with open(target.with_suffix(".json"), "w", encoding="utf-8") as jf:
        json.dump(breakdown, jf, indent=4, ensure_ascii=False)
    np.save(target, frame_counts)
@torch.inference_mode()
def compute_durations(data_loader: torch.utils.data.DataLoader, model: nn.Module, device: torch.device, output_folder):
    """Generate durations from the model for each datapoint and save it in a folder

    Args:
        data_loader (torch.utils.data.DataLoader): Dataloader
        model (nn.Module): MatchaTTS model
        device (torch.device): GPU or CPU
    """
    for batch in tqdm(data_loader, desc="🍵 Computing durations 🍵:"):
        x, x_lengths = batch["x"], batch["x_lengths"]
        y, y_lengths = batch["y"], batch["y_lengths"]
        spks = batch["spks"]
        # Move everything to the inference device; spks is optional (multi-speaker only).
        x = x.to(device)
        y = y.to(device)
        x_lengths = x_lengths.to(device)
        y_lengths = y_lengths.to(device)
        spks = spks.to(device) if spks is not None else None
        # The model's forward returns four values; only the alignment matrix is needed here.
        _, _, _, attn = model(
            x=x,
            x_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            spks=spks,
        )
        attn = attn.cpu()
        # Save one durations file per utterance in the batch.
        for i in range(attn.shape[0]):
            save_durations_to_folder(
                attn[i],
                x_lengths[i].item(),
                y_lengths[i].item(),
                batch["filepaths"][i],
                output_folder,
                batch["x_texts"][i],
            )
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-i"
,
"--input-config"
,
type
=
str
,
default
=
"ljspeech.yaml"
,
help
=
"The name of the yaml config file under configs/data"
,
)
parser
.
add_argument
(
"-b"
,
"--batch-size"
,
type
=
int
,
default
=
"32"
,
help
=
"Can have increased batch size for faster computation"
,
)
parser
.
add_argument
(
"-f"
,
"--force"
,
action
=
"store_true"
,
default
=
False
,
required
=
False
,
help
=
"force overwrite the file"
,
)
parser
.
add_argument
(
"-c"
,
"--checkpoint_path"
,
type
=
str
,
required
=
True
,
help
=
"Path to the checkpoint file to load the model from"
,
)
parser
.
add_argument
(
"-o"
,
"--output-folder"
,
type
=
str
,
default
=
None
,
help
=
"Output folder to save the data statistics"
,
)
parser
.
add_argument
(
"--cpu"
,
action
=
"store_true"
,
help
=
"Use CPU for inference, not recommended (default: use GPU if available)"
)
args
=
parser
.
parse_args
()
with
initialize
(
version_base
=
"1.3"
,
config_path
=
"../../configs/data"
):
cfg
=
compose
(
config_name
=
args
.
input_config
,
return_hydra_config
=
True
,
overrides
=
[])
root_path
=
rootutils
.
find_root
(
search_from
=
__file__
,
indicator
=
".project-root"
)
with
open_dict
(
cfg
):
del
cfg
[
"hydra"
]
del
cfg
[
"_target_"
]
cfg
[
"seed"
]
=
1234
cfg
[
"batch_size"
]
=
args
.
batch_size
cfg
[
"train_filelist_path"
]
=
str
(
os
.
path
.
join
(
root_path
,
cfg
[
"train_filelist_path"
]))
cfg
[
"valid_filelist_path"
]
=
str
(
os
.
path
.
join
(
root_path
,
cfg
[
"valid_filelist_path"
]))
cfg
[
"load_durations"
]
=
False
if
args
.
output_folder
is
not
None
:
output_folder
=
Path
(
args
.
output_folder
)
else
:
output_folder
=
Path
(
cfg
[
"train_filelist_path"
]).
parent
/
"durations"
print
(
f
"Output folder set to:
{
output_folder
}
"
)
if
os
.
path
.
exists
(
output_folder
)
and
not
args
.
force
:
print
(
"Folder already exists. Use -f to force overwrite"
)
sys
.
exit
(
1
)
output_folder
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
print
(
f
"Preprocessing:
{
cfg
[
'name'
]
}
from training filelist:
{
cfg
[
'train_filelist_path'
]
}
"
)
print
(
"Loading model..."
)
device
=
get_device
(
args
)
model
=
MatchaTTS
.
load_from_checkpoint
(
args
.
checkpoint_path
,
map_location
=
device
)
text_mel_datamodule
=
TextMelDataModule
(
**
cfg
)
text_mel_datamodule
.
setup
()
try
:
print
(
"Computing stats for training set if exists..."
)
train_dataloader
=
text_mel_datamodule
.
train_dataloader
()
compute_durations
(
train_dataloader
,
model
,
device
,
output_folder
)
except
lightning
.
fabric
.
utilities
.
exceptions
.
MisconfigurationException
:
print
(
"No training set found"
)
try
:
print
(
"Computing stats for validation set if exists..."
)
val_dataloader
=
text_mel_datamodule
.
val_dataloader
()
compute_durations
(
val_dataloader
,
model
,
device
,
output_folder
)
except
lightning
.
fabric
.
utilities
.
exceptions
.
MisconfigurationException
:
print
(
"No validation set found"
)
try
:
print
(
"Computing stats for test set if exists..."
)
test_dataloader
=
text_mel_datamodule
.
test_dataloader
()
compute_durations
(
test_dataloader
,
model
,
device
,
output_folder
)
except
lightning
.
fabric
.
utilities
.
exceptions
.
MisconfigurationException
:
print
(
"No test set found"
)
print
(
f
"[+] Done! Data statistics saved to:
{
output_folder
}
"
)
if __name__ == "__main__":
    # Helps with generating durations for the dataset to train other architectures
    # that cannot learn to align due to limited size of dataset
    # Example usage:
    # python matcha/utils/get_durations_from_trained_model.py -i ljspeech.yaml -c pretrained_model
    # This will create a folder in data/processed_data/durations/ljspeech with the durations
    main()
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/instantiators.py
0 → 100644
View file @
39ac40a9
from
typing
import
List
import
hydra
from
lightning
import
Callback
from
lightning.pytorch.loggers
import
Logger
from
omegaconf
import
DictConfig
from
matcha.utils
import
pylogger
log
=
pylogger
.
get_pylogger
(
__name__
)
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks.
    """
    instantiated: List[Callback] = []

    # Nothing configured: warn and hand back an empty list.
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return instantiated

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    # Only entries that carry a `_target_` key are instantiable callback specs;
    # anything else is silently skipped.
    for conf in callbacks_cfg.values():
        if isinstance(conf, DictConfig) and "_target_" in conf:
            log.info(f"Instantiating callback <{conf._target_}>")  # pylint: disable=protected-access
            instantiated.append(hydra.utils.instantiate(conf))

    return instantiated
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers.
    """
    instantiated: List[Logger] = []

    # Nothing configured: warn and hand back an empty list.
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return instantiated

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    # Only entries that carry a `_target_` key are instantiable logger specs.
    for conf in logger_cfg.values():
        if isinstance(conf, DictConfig) and "_target_" in conf:
            log.info(f"Instantiating logger <{conf._target_}>")  # pylint: disable=protected-access
            instantiated.append(hydra.utils.instantiate(conf))

    return instantiated
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/logging_utils.py
0 → 100644
View file @
39ac40a9
from
typing
import
Any
,
Dict
from
lightning.pytorch.utilities
import
rank_zero_only
from
omegaconf
import
OmegaConf
from
matcha.utils
import
pylogger
log
=
pylogger
.
get_pylogger
(
__name__
)
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
        - Number of model parameters

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    trainer = object_dict["trainer"]
    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]

    # Parameter counts, split by trainability.
    n_total = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)

    hparams = {
        "model": cfg["model"],
        "model/params/total": n_total,
        "model/params/trainable": n_trainable,
        "model/params/non_trainable": n_frozen,
        "data": cfg["data"],
        "trainer": cfg["trainer"],
        # Optional sections — `.get` tolerates their absence.
        "callbacks": cfg.get("callbacks"),
        "extras": cfg.get("extras"),
        "task_name": cfg.get("task_name"),
        "tags": cfg.get("tags"),
        "ckpt_path": cfg.get("ckpt_path"),
        "seed": cfg.get("seed"),
    }

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/model.py
0 → 100644
View file @
39ac40a9
""" from https://github.com/jaywalnut310/glow-tts """
import
numpy
as
np
import
torch
def sequence_mask(length, max_length=None):
    """Build a boolean mask of shape (len(length), max_length).

    Entry (i, j) is True iff j < length[i]; max_length defaults to length.max().
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round `length` up to a multiple of 2**num_downsamplings_in_unet.

    Returns a plain Python int normally; during ONNX export the tensor is
    returned as-is so it stays traceable.
    """
    factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
    rounded = (length / factor).ceil() * factor
    if torch.onnx.is_in_onnx_export():
        return rounded
    return rounded.int().item()
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension [[before, after], ...] pad spec into the
    reversed flat list that torch.nn.functional.pad expects (last dim first)."""
    flat = []
    for pair in reversed(pad_shape):
        flat.extend(pair)
    return flat
def generate_path(duration, mask):
    """Expand per-position durations into a hard monotonic alignment path.

    :param duration: (b, t_x) durations per source position.
    :param mask: (b, t_x, t_y) attention mask; the result carries its dtype.
    :return: (b, t_x, t_y) 0/1 path where each target frame is assigned to
        exactly one source position, masked by `mask`.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    # Removed a dead `torch.zeros(b, t_x, t_y)` allocation: the original built
    # a zero tensor that was immediately overwritten by sequence_mask below.
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtracting the path shifted by one along t_x turns cumulative coverage
    # into disjoint per-position segments.
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    return path * mask
def duration_loss(logw, logw_, lengths):
    """Squared error between predicted and target log-durations,
    normalized by the total of `lengths`."""
    squared_error = (logw - logw_).pow(2).sum()
    return squared_error / lengths.sum()
def normalize(data, mu, std):
    """Normalize `data` as (data - mu) / std.

    `mu` and `std` may each be a scalar (float/int), a list, a torch.Tensor,
    or a np.ndarray; non-scalars are moved to `data`'s device and given a
    trailing singleton dim so they broadcast over the last axis.
    """

    def _as_column(stat):
        # Scalars broadcast as-is; everything else becomes a (..., 1) tensor.
        if isinstance(stat, (float, int)):
            return stat
        if isinstance(stat, list):
            stat = torch.tensor(stat, dtype=data.dtype, device=data.device)
        elif isinstance(stat, torch.Tensor):
            stat = stat.to(data.device)
        elif isinstance(stat, np.ndarray):
            stat = torch.from_numpy(stat).to(data.device)
        return stat.unsqueeze(-1)

    return (data - _as_column(mu)) / _as_column(std)
def denormalize(data, mu, std):
    """Invert `normalize`: return data * std + mu.

    :param data: Input tensor.
    :param mu: Mean — scalar (float/int), list, torch.Tensor, or np.ndarray.
    :param std: Standard deviation — same accepted types as `mu`.
    :return: De-normalized tensor; non-scalar stats are moved to `data`'s
        device and broadcast over the last axis via a trailing singleton dim.
    """
    # Bug fix: the scalar check previously tested `isinstance(mu, float)` only,
    # while the paired `normalize` accepts `(float, int)`. An int mu/std fell
    # through the conversion chain untouched and crashed on `.unsqueeze(-1)`.
    if not isinstance(mu, (float, int)):
        if isinstance(mu, list):
            mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
        elif isinstance(mu, torch.Tensor):
            mu = mu.to(data.device)
        elif isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).to(data.device)
        mu = mu.unsqueeze(-1)

    if not isinstance(std, (float, int)):
        if isinstance(std, list):
            std = torch.tensor(std, dtype=data.dtype, device=data.device)
        elif isinstance(std, torch.Tensor):
            std = std.to(data.device)
        elif isinstance(std, np.ndarray):
            std = torch.from_numpy(std).to(data.device)
        std = std.unsqueeze(-1)

    return data * std + mu
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py
0 → 100644
View file @
39ac40a9
import
numpy
as
np
import
torch
from
matcha.utils.monotonic_align.core
import
maximum_path_c
def maximum_path(value, mask):
    """Cython optimised version.

    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    masked = value * mask
    device, dtype = masked.device, masked.dtype

    # The Cython kernel works on contiguous float32/int32 numpy arrays and
    # writes the chosen path into `path_np` in place.
    values_np = masked.data.cpu().numpy().astype(np.float32)
    path_np = np.zeros_like(values_np).astype(np.int32)
    mask_np = mask.data.cpu().numpy()
    # Per-batch valid extents along each axis, read off the mask.
    t_x_max = mask_np.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask_np.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path_np, values_np, t_x_max, t_y_max)
    return torch.from_numpy(path_np).to(device=device, dtype=dtype)
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx
0 → 100644
View file @
39ac40a9
import
numpy
as
np
cimport
cython
cimport
numpy
as
np
from
cython.parallel
import
prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
    # Dynamic-programming pass for a single (t_x, t_y) score matrix:
    # forward accumulation of best scores in `value` (modified in place),
    # then a backward trace writing the chosen monotonic path into `path`.
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp  # NOTE(review): unused — kept for byte-compatibility
    cdef int index = t_x - 1

    # Forward pass: only cells with x <= y and x >= t_x + y - t_y are reachable
    # by a monotonic path, hence the clamped inner range.
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                # On the diagonal there is no "stay" predecessor.
                v_cur = max_neg_val
            else:
                v_cur = value[x, y-1]
            if x == 0:
                if y == 0:
                    v_prev = 0.
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[x-1, y-1]
            value[x, y] = max(v_cur, v_prev) + value[x, y]

    # Backward trace: walk from the last column, stepping `index` down when
    # the diagonal predecessor scored better (or when forced by index == y).
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
            index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
    # Batched entry point: runs maximum_path_each over every batch element in
    # parallel (prange releases the GIL per iteration). `paths` is filled in
    # place; `values` is clobbered by the forward DP pass.
    cdef int b = values.shape[0]
    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py
0 → 100644
View file @
39ac40a9
# from distutils.core import setup
# from Cython.Build import cythonize
# import numpy
# setup(name='monotonic_align',
# ext_modules=cythonize("core.pyx"),
# include_dirs=[numpy.get_include()])
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/pylogger.py
0 → 100644
View file @
39ac40a9
import
logging
from
lightning.pytorch.utilities
import
rank_zero_only
def get_pylogger(name: str = __name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger.

    :param name: The name of the logger, defaults to ``__name__``.
    :return: A logger object.
    """
    logger = logging.getLogger(name)

    # Wrap every level method with the rank-zero decorator; otherwise each
    # GPU process in a multi-GPU setup would emit a duplicate of every record.
    for method_name in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        wrapped = rank_zero_only(getattr(logger, method_name))
        setattr(logger, method_name, wrapped)

    return logger
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/rich_utils.py
0 → 100644
View file @
39ac40a9
from
pathlib
import
Path
from
typing
import
Sequence
import
rich
import
rich.syntax
import
rich.tree
from
hydra.core.hydra_config
import
HydraConfig
from
lightning.pytorch.utilities
import
rank_zero_only
from
omegaconf
import
DictConfig
,
OmegaConf
,
open_dict
from
rich.prompt
import
Prompt
from
matcha.utils
import
pylogger
log
=
pylogger
.
get_pylogger
(
__name__
)
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Prints the contents of a DictConfig as a tree structure using the Rich library.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    queue = []

    # add fields from `print_order` to queue (missing fields are only warned about)
    for field in print_order:
        _ = (
            queue.append(field)
            if field in cfg
            else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
        )

    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)

    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, style=style, guide_style=style)

        config_group = cfg[field]
        # Sub-configs render as YAML; scalar leaves render via str().
        if isinstance(config_group, DictConfig):
            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
        else:
            branch_content = str(config_group)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    # print config tree
    rich.print(tree)

    # save config tree to file (into the hydra run's output dir)
    if save_to_file:
        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w", encoding="utf-8") as file:
            rich.print(tree, file=file)
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    """
    if not cfg.get("tags"):
        # Multiruns cannot pause for interactive input, so fail fast there.
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Split the comma-separated answer, trimming whitespace and dropping empties.
        tags = [t.strip() for t in tags.split(",") if t != ""]

        # open_dict temporarily unlocks the (possibly struct-frozen) config.
        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)
third_party/GLM-4-Voice/third_party/Matcha-TTS/matcha/utils/utils.py
0 → 100644
View file @
39ac40a9
import
os
import
sys
import
warnings
from
importlib.util
import
find_spec
from
math
import
ceil
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
Tuple
import
gdown
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
torch
import
wget
from
omegaconf
import
DictConfig
from
matcha.utils
import
pylogger
,
rich_utils
log
=
pylogger
.
get_pylogger
(
__name__
)
def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # return if no `extras` config
    extras_cfg = cfg.get("extras")
    if not extras_cfg:
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if extras_cfg.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if extras_cfg.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # pretty print config tree using Rich library
    if extras_cfg.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.
    :return: The wrapped task function.
    """
    # NOTE(review): the wrapper does not apply functools.wraps, so the wrapped
    # function's __name__/__doc__ are not preserved — confirm if that matters.

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception as ex:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise ex

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: The name of the metric to retrieve.
    :return: The value of the metric (or None when no name was given).
    """
    # An empty/None metric name means nothing was requested — not an error.
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise ValueError(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    # .item() unwraps the logged scalar into a plain Python number.
    value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={value}>")

    return value
def intersperse(lst, item):
    """Return a new list with `item` placed before, between, and after every
    element of `lst` (the blank-symbol interleaving used for TTS inputs)."""
    out = []
    for element in lst:
        out.append(item)
        out.append(element)
    out.append(item)
    return out
def save_figure_to_numpy(fig):
    """Convert a drawn matplotlib figure canvas to an RGB uint8 array.

    :param fig: A figure whose canvas has already been drawn (fig.canvas.draw()).
    :return: np.ndarray of shape (height, width, 3), dtype uint8, writable.
    """
    # Bug fix: np.fromstring is deprecated and removed in NumPy 2.0; frombuffer
    # reads the same bytes. .copy() keeps the result writable, matching the
    # copy that fromstring used to return.
    # NOTE(review): canvas.tostring_rgb is itself deprecated in Matplotlib 3.8+
    # (buffer_rgba is the replacement, but it yields 4 channels) — revisit when
    # bumping matplotlib.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
    # get_width_height() reports (width, height); reshape wants rows first.
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data
def plot_tensor(tensor):
    """Render a 2-D tensor (e.g. a mel spectrogram) as an image and return it
    as an RGB uint8 numpy array (via save_figure_to_numpy)."""
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    # The canvas must be drawn before its pixel buffer can be read.
    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    # Close the figure to free pyplot's global figure registry.
    plt.close()
    return data
def save_plot(tensor, savepath):
    """Render a 2-D tensor (e.g. a mel spectrogram) and save the image to
    `savepath` via plt.savefig."""
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    plt.savefig(savepath)
    # Close the figure to free pyplot's global figure registry.
    plt.close()
def to_numpy(tensor):
    """Convert a torch.Tensor, list, or np.ndarray to a numpy array.

    ndarrays are returned unchanged (no copy); tensors are detached and moved
    to CPU first. Any other type raises TypeError.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, list):
        return np.array(tensor)
    raise TypeError("Unsupported type for conversion to numpy array")
def get_user_data_dir(appname="matcha_tts"):
    """
    Args:
        appname (str): Name of application

    Returns:
        Path: path to user data directory (created if missing)
    """
    env_home = os.environ.get("MATCHA_HOME")
    if env_home is not None:
        # Explicit override always wins.
        base = Path(env_home).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel

        # Resolve the per-user Local AppData folder from the registry.
        key = winreg.OpenKey(
            winreg.HKEY_CURRENT_USER,
            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
        )
        appdata_dir, _ = winreg.QueryValueEx(key, "Local AppData")
        base = Path(appdata_dir).resolve(strict=False)
    elif sys.platform == "darwin":
        base = Path("~/Library/Application Support/").expanduser()
    else:
        # XDG-style default on Linux and other POSIX systems.
        base = Path.home().joinpath(".local/share")

    target = base / appname
    target.mkdir(parents=True, exist_ok=True)
    return target
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
    """Ensure a model checkpoint exists locally, downloading it if needed.

    :param checkpoint_path: Target path for the checkpoint file.
    :param url: Source URL (a Google Drive link when use_wget is False).
    :param use_wget: When True download via wget, otherwise via gdown.
    """
    # Already present: report on both the logger and stdout, then bail out.
    if Path(checkpoint_path).exists():
        log.debug(f"[+] Model already present at {checkpoint_path}!")
        print(f"[+] Model already present at {checkpoint_path}!")
        return
    log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
    print(f"[-] Model not found at {checkpoint_path}! Will download it")
    # Both downloaders expect a plain string path, not a Path object.
    checkpoint_path = str(checkpoint_path)
    if not use_wget:
        # fuzzy=True lets gdown accept full Google Drive share links.
        gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
    else:
        wget.download(url=url, out=checkpoint_path)
def get_phoneme_durations(durations, phones):
    """Merge per-token durations of a blank-interspersed sequence onto phones.

    The asserts below require len(durations) == 2 * len(phones) + 1, i.e. the
    token sequence is phones interleaved with blanks (presumably produced by
    `intersperse` — confirm with the caller). Each blank's duration is split:
    the ceil-half goes to the phone on its left, the remainder to the phone on
    its right, except the final blank which goes entirely to the last phone.

    :param durations: per-token frame counts, length 2 * len(phones) + 1.
    :param phones: phone labels, one per odd-indexed token.
    :return: list of {phone: {"starttime", "endtime", "duration"}} dicts in
        cumulative frame units; the last endtime equals sum(durations).
    """
    # `prev` carries the leftover half of the preceding blank (initially the
    # leading blank in full).
    prev = durations[0]
    merged_durations = []
    # Convolve with stride 2: phones sit at odd indices, blanks at even ones.
    for i in range(1, len(durations), 2):
        if i == len(durations) - 2:
            # if it is last take full value
            next_half = durations[i + 1]
        else:
            next_half = ceil(durations[i + 1] / 2)

        curr = prev + durations[i] + next_half
        prev = durations[i + 1] - next_half
        merged_durations.append(curr)

    # Sanity: every phone got a duration and the token count matched.
    assert len(phones) == len(merged_durations)
    assert len(merged_durations) == (len(durations) - 1) // 2
    # Cumulative frame indices give each phone's end time directly.
    merged_durations = torch.cumsum(torch.tensor(merged_durations), 0, dtype=torch.long)
    start = torch.tensor(0)
    duration_json = []
    for i, duration in enumerate(merged_durations):
        duration_json.append(
            {
                phones[i]: {
                    "starttime": start.item(),
                    "endtime": duration.item(),
                    "duration": duration.item() - start.item(),
                }
            }
        )
        start = duration

    # The merge must conserve the total number of frames.
    assert list(duration_json[-1].values())[0]["endtime"] == sum(
        durations
    ), f"{list(duration_json[-1].values())[0]['endtime'], sum(durations)}"
    return duration_json
Prev
1
…
9
10
11
12
13
14
15
16
17
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment