renzhc / diffusers_dcu · Commits · ace07110

Commit ace07110, authored Jun 16, 2022 by patil-suraj

    style

parent 988369a0

Showing 4 changed files with 198 additions and 115 deletions (+198 -115)
src/diffusers/models/unet_grad_tts.py (+3 -2)
src/diffusers/pipelines/grad_tts_utils.py (+170 -92)
src/diffusers/pipelines/pipeline_grad_tts.py (+24 -20)
src/diffusers/schedulers/__init__.py (+1 -1)
src/diffusers/models/unet_grad_tts.py

```diff
@@ -145,8 +145,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
-            self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
-                                               torch.nn.Linear(spk_emb_dim * 4, n_feats))
+            self.spk_mlp = torch.nn.Sequential(
+                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
+            )
         self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
```
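As context for the reflow: `spk_emb`/`spk_mlp` form the multi-speaker conditioning path, projecting a learned speaker embedding up by 4x and back down to `n_feats` mel channels. A minimal shape-check sketch under assumed sizes (`n_spks=247`, `spk_emb_dim=64`, `n_feats=80` are illustrative, and `Mish` is re-created here from its standard definition, not copied from this diff):

```python
import torch

n_spks, spk_emb_dim, n_feats = 247, 64, 80  # illustrative sizes only


class Mish(torch.nn.Module):
    # Standard Mish activation: x * tanh(softplus(x))
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
spk_mlp = torch.nn.Sequential(
    torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
)

speaker_id = torch.LongTensor([15])
print(spk_mlp(spk_emb(speaker_id)).shape)  # torch.Size([1, 80])
```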
src/diffusers/pipelines/grad_tts_utils.py

```diff
 # tokenizer
-import re
 import os
+import re
 from shutil import copyfile

 import torch

 try:
     from transformers import PreTrainedTokenizer
 except:
```
```diff
@@ -25,17 +26,95 @@ except:
 valid_symbols = [
-    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
-    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
-    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2',
-    'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY',
-    'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1',
-    'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH',
-    'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
+    "AA",
+    "AA0",
+    "AA1",
+    "AA2",
+    "AE",
+    "AE0",
+    "AE1",
+    "AE2",
+    "AH",
+    "AH0",
+    "AH1",
+    "AH2",
+    "AO",
+    "AO0",
+    "AO1",
+    "AO2",
+    "AW",
+    "AW0",
+    "AW1",
+    "AW2",
+    "AY",
+    "AY0",
+    "AY1",
+    "AY2",
+    "B",
+    "CH",
+    "D",
+    "DH",
+    "EH",
+    "EH0",
+    "EH1",
+    "EH2",
+    "ER",
+    "ER0",
+    "ER1",
+    "ER2",
+    "EY",
+    "EY0",
+    "EY1",
+    "EY2",
+    "F",
+    "G",
+    "HH",
+    "IH",
+    "IH0",
+    "IH1",
+    "IH2",
+    "IY",
+    "IY0",
+    "IY1",
+    "IY2",
+    "JH",
+    "K",
+    "L",
+    "M",
+    "N",
+    "NG",
+    "OW",
+    "OW0",
+    "OW1",
+    "OW2",
+    "OY",
+    "OY0",
+    "OY1",
+    "OY2",
+    "P",
+    "R",
+    "S",
+    "SH",
+    "T",
+    "TH",
+    "UH",
+    "UH0",
+    "UH1",
+    "UH2",
+    "UW",
+    "UW0",
+    "UW1",
+    "UW2",
+    "V",
+    "W",
+    "Y",
+    "Z",
+    "ZH",
 ]

 _valid_symbol_set = set(valid_symbols)


 def intersperse(lst, item):
     # Adds blank symbol
     result = [item] * (len(lst) * 2 + 1)
```
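The hunk truncates `intersperse` after its first line; in the upstream Grad-TTS source the function then fills the odd positions with the original tokens. A minimal sketch of that behavior (the `result[1::2] = lst` completion is recalled from upstream, not shown in this hunk):

```python
def intersperse(lst, item):
    # Adds blank symbol between every token: [a, b] -> [item, a, item, b, item]
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst  # completion recalled from upstream Grad-TTS
    return result


print(intersperse([5, 7], 0))  # [0, 5, 0, 7, 0]
```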
```diff
@@ -46,7 +125,7 @@ def intersperse(lst, item):
 class CMUDict:
     def __init__(self, file_or_path, keep_ambiguous=True):
         if isinstance(file_or_path, str):
-            with open(file_or_path, encoding='latin-1') as f:
+            with open(file_or_path, encoding="latin-1") as f:
                 entries = _parse_cmudict(f)
         else:
             entries = _parse_cmudict(file_or_path)
```
```diff
@@ -61,15 +140,15 @@ class CMUDict:
         return self._entries.get(word.upper())


-_alt_re = re.compile(r'\([0-9]+\)')
+_alt_re = re.compile(r"\([0-9]+\)")


 def _parse_cmudict(file):
     cmudict = {}
     for line in file:
-        if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
-            parts = line.split('  ')
-            word = re.sub(_alt_re, '', parts[0])
+        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
+            parts = line.split("  ")
+            word = re.sub(_alt_re, "", parts[0])
             pronunciation = _get_pronunciation(parts[1])
             if pronunciation:
                 if word in cmudict:
```
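As context for this hunk: `CMUDict` accepts either a path or an open file object, and `_alt_re` strips the `(1)`, `(2)` markers CMUdict uses for alternate pronunciations. A small hedged sketch (the two entries are illustrative; `lookup` is the accessor whose body appears above):

```python
import io

# Two entries in CMUdict format: word and pronunciation separated by two spaces.
fake_dict = io.StringIO("HELLO  HH AH0 L OW1\nHOUSTON  HH Y UW1 S T AH0 N\n")
cmu = CMUDict(fake_dict)
print(cmu.lookup("hello"))  # ['HH AH0 L OW1'] -- lookup upper-cases the query
```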
```diff
@@ -80,36 +159,38 @@ def _parse_cmudict(file):
 def _get_pronunciation(s):
-    parts = s.strip().split(' ')
+    parts = s.strip().split(" ")
     for part in parts:
         if part not in _valid_symbol_set:
             return None
-    return ' '.join(parts)
+    return " ".join(parts)


-_whitespace_re = re.compile(r'\s+')
+_whitespace_re = re.compile(r"\s+")

-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
+_abbreviations = [
+    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+    for x in [
+        ("mrs", "misess"),
+        ("mr", "mister"),
+        ("dr", "doctor"),
+        ("st", "saint"),
+        ("co", "company"),
+        ("jr", "junior"),
+        ("maj", "major"),
+        ("gen", "general"),
+        ("drs", "doctors"),
+        ("rev", "reverend"),
+        ("lt", "lieutenant"),
+        ("hon", "honorable"),
+        ("sgt", "sergeant"),
+        ("capt", "captain"),
+        ("esq", "esquire"),
+        ("ltd", "limited"),
+        ("col", "colonel"),
+        ("ft", "fort"),
+    ]
+]


 def expand_abbreviations(text):
```
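`expand_abbreviations` itself is unchanged (and elided) here; per the upstream Tacotron utilities it applies each compiled pattern in order, substituting the lowercase expansion. An illustrative check:

```python
print(expand_abbreviations("Mr. and Mrs. Smith met Dr. Jones."))
# -> "mister and misess Smith met doctor Jones."
```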
```diff
@@ -127,7 +208,7 @@ def lowercase(text):
 def collapse_whitespace(text):
-    return re.sub(_whitespace_re, ' ', text)
+    return re.sub(_whitespace_re, " ", text)


 def convert_to_ascii(text):
```
```diff
@@ -156,46 +237,42 @@ def english_cleaners(text):
     return text


 _inflect = inflect.engine()
-_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
-_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
-_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
-_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
-_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
-_number_re = re.compile(r'[0-9]+')
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")


 def _remove_commas(m):
-    return m.group(1).replace(',', '')
+    return m.group(1).replace(",", "")


 def _expand_decimal_point(m):
-    return m.group(1).replace('.', ' point ')
+    return m.group(1).replace(".", " point ")


 def _expand_dollars(m):
     match = m.group(1)
-    parts = match.split('.')
+    parts = match.split(".")
     if len(parts) > 2:
-        return match + ' dollars'
+        return match + " dollars"
     dollars = int(parts[0]) if parts[0] else 0
     cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
     if dollars and cents:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
     elif dollars:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        return '%s %s' % (dollars, dollar_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        return "%s %s" % (dollars, dollar_unit)
     elif cents:
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s' % (cents, cent_unit)
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s" % (cents, cent_unit)
     else:
-        return 'zero dollars'
+        return "zero dollars"


 def _expand_ordinal(m):
```
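Since `_dollars_re` and `_expand_dollars` are fully visible in this hunk, a quick illustrative check of the currency branch (applying only the dollar rule, so the digits are not yet spelled out):

```python
print(re.sub(_dollars_re, _expand_dollars, "It costs $3.50."))
# -> "It costs 3 dollars, 50 cents."
print(re.sub(_dollars_re, _expand_dollars, "$7"))
# -> "7 dollars"
```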
```diff
@@ -206,37 +283,37 @@ def _expand_number(m):
     num = int(m.group(0))
     if num > 1000 and num < 3000:
         if num == 2000:
-            return 'two thousand'
+            return "two thousand"
         elif num > 2000 and num < 2010:
-            return 'two thousand ' + _inflect.number_to_words(num % 100)
+            return "two thousand " + _inflect.number_to_words(num % 100)
         elif num % 100 == 0:
-            return _inflect.number_to_words(num // 100) + ' hundred'
+            return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
     else:
-        return _inflect.number_to_words(num, andword='')
+        return _inflect.number_to_words(num, andword="")


 def normalize_numbers(text):
     text = re.sub(_comma_number_re, _remove_commas, text)
-    text = re.sub(_pounds_re, r'\1 pounds', text)
+    text = re.sub(_pounds_re, r"\1 pounds", text)
     text = re.sub(_dollars_re, _expand_dollars, text)
     text = re.sub(_decimal_number_re, _expand_decimal_point, text)
     text = re.sub(_ordinal_re, _expand_ordinal, text)
     text = re.sub(_number_re, _expand_number, text)
     return text
```
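`normalize_numbers` chains the rules in a deliberate order: commas are stripped first so the digit patterns see plain numbers, and dollar amounts are expanded before the generic number pass spells out any remaining digits. An illustrative run (chosen to avoid the ordinal branch, whose helper body is elided above):

```python
print(normalize_numbers("On May 3, 1,500 people paid $5."))
# -> "On May three, fifteen hundred people paid five dollars."
# "1,500" -> "1500" -> "fifteen hundred"; "$5" -> "5 dollars" -> "five dollars"
```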
""" from https://github.com/keithito/tacotron """
_pad
=
'_'
_punctuation
=
'!
\
'
(),.:;?
'
_special
=
'-'
_letters
=
'
ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
'
_pad
=
"_"
_punctuation
=
"!
'(),.:;?
"
_special
=
"-"
_letters
=
"
ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
"
# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet
=
[
'@'
+
s
for
s
in
valid_symbols
]
_arpabet
=
[
"@"
+
s
for
s
in
valid_symbols
]
# Export all symbols:
symbols
=
[
_pad
]
+
list
(
_special
)
+
list
(
_punctuation
)
+
list
(
_letters
)
+
_arpabet
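Putting the pieces together, the exported table is pad + special + punctuation + letters + the 84 @-prefixed ARPAbet symbols:

```python
print(symbols[:6])   # ['_', '-', '!', "'", '(', ')']
print(len(symbols))  # 148 = 1 pad + 1 special + 10 punctuation + 52 letters + 84 ARPAbet
```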
```diff
@@ -245,7 +322,7 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpab
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
 _id_to_symbol = {i: s for i, s in enumerate(symbols)}

-_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
+_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


 def get_arpabet(word, dictionary):
```
```diff
@@ -257,7 +334,7 @@ def get_arpabet(word, dictionary):
 def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
-    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

     The text can optionally have ARPAbet sequences enclosed in curly braces embedded
     in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
```
```diff
@@ -269,9 +346,9 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
     Returns:
         List of integers corresponding to the symbols in the text
-    '''
+    """
     sequence = []

-    space = _symbols_to_sequence(' ')
+    space = _symbols_to_sequence(" ")
     # Check for curly braces and treat their contents as ARPAbet:
     while len(text):
         m = _curly_re.match(text)
```
```diff
@@ -300,16 +377,16 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
 def sequence_to_text(sequence):
-    '''Converts a sequence of IDs back to a string'''
-    result = ''
+    """Converts a sequence of IDs back to a string"""
+    result = ""
     for symbol_id in sequence:
         if symbol_id in _id_to_symbol:
             s = _id_to_symbol[symbol_id]
             # Enclose ARPAbet back in curly braces:
-            if len(s) > 1 and s[0] == '@':
-                s = '{%s}' % s[1:]
+            if len(s) > 1 and s[0] == "@":
+                s = "{%s}" % s[1:]
             result += s
-    return result.replace('}{', ' ')
+    return result.replace("}{", " ")


 def _clean_text(text, cleaner_names):
```
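A round trip through `text_to_sequence` and `sequence_to_text` (illustrative; `english_cleaners`, whose body is elided here, lowercases plain text while the braced ARPAbet region passes straight through as symbol IDs):

```python
seq = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.")
print(sequence_to_text(seq))
# -> "turn left on {HH AW1 S S T AH0 N} street."
```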
```diff
@@ -323,17 +400,18 @@ def _symbols_to_sequence(symbols):
 def _arpabet_to_sequence(text):
-    return _symbols_to_sequence(['@' + s for s in text.split()])
+    return _symbols_to_sequence(["@" + s for s in text.split()])


 def _should_keep_symbol(s):
-    return s in _symbol_to_id and s != '_' and s != '~'
+    return s in _symbol_to_id and s != "_" and s != "~"


 VOCAB_FILES_NAMES = {
     "dict_file": "dict_file.txt",
 }


 class GradTTSTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
```
```diff
@@ -347,11 +425,11 @@ class GradTTSTokenizer(PreTrainedTokenizer):
         x_lengths = torch.LongTensor([x.shape[-1]])
         return x, x_lengths

-    def save_vocabulary(self, save_directory: str, filename_prefix=None):
-        dict_file = os.path.join(save_directory,
-                                 (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"])
+    def save_vocabulary(self, save_directory: str, filename_prefix=None):
+        dict_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
+        )
         copyfile(self.dict_file, dict_file)
-        return (dict_file, )
+        return (dict_file,)
```
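A hedged usage sketch of the tokenizer (the checkpoint path is hypothetical; `from_pretrained` comes from the `PreTrainedTokenizer` base class, and the `__call__` shown just above this hunk returns the ID tensor together with its length):

```python
# Hypothetical path; nothing in this diff names a real checkpoint.
tokenizer = GradTTSTokenizer.from_pretrained("path/to/grad-tts-tokenizer")

x, x_lengths = tokenizer("Hello world")  # LongTensor of phoneme IDs, LongTensor([len])
tokenizer.save_vocabulary("./exported")  # copies dict_file.txt into ./exported
```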
src/diffusers/pipelines/pipeline_grad_tts.py

```diff
@@ -4,11 +4,11 @@ import math
 import torch
 from torch import nn

-import tqdm
+import tqdm

-from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
+from diffusers import DiffusionPipeline
 from .grad_tts_utils import GradTTSTokenizer

 # flake8: noqa
```
```diff
@@ -424,10 +424,14 @@ class GradTTS(DiffusionPipeline):
     def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
+        self.register_modules(
+            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
+        )

     @torch.no_grad()
-    def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None):
+    def __call__(
+        self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
+    ):
         if torch_device is None:
             torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```
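A hedged end-to-end sketch of the pipeline (the checkpoint name is hypothetical; the call signature matches the reformatted `__call__` above):

```python
# Hypothetical checkpoint; the diff does not name one.
pipeline = GradTTS.from_pretrained("path/to/grad-tts")

mel = pipeline("Hello world", num_inference_steps=50, temperature=1.3, speaker_id=15)
# The pipeline produces a mel-spectrogram; a separate vocoder is needed to get audio.
```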
In this hunk and the next, the removed and added lines are textually identical; the differences are whitespace-only reflows:

```diff
@@ -439,7 +443,7 @@ class GradTTS(DiffusionPipeline):
         x_lengths = x_lengths.to(torch_device)

         if speaker_id is not None:
-            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)
+            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)

         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
         mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
```
```diff
@@ -465,7 +469,7 @@ class GradTTS(DiffusionPipeline):
         xt = z * y_mask
         h = 1.0 / num_inference_steps
         for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
-            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
+            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
             time = t.unsqueeze(-1).unsqueeze(-1)
             residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
```
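The loop above walks diffusion time backwards on a midpoint grid, t = 1 - (i + 0.5) * h with h = 1 / num_inference_steps. A quick check of the grid for four steps:

```python
num_inference_steps = 4
h = 1.0 / num_inference_steps
print([round(1.0 - (i + 0.5) * h, 3) for i in range(num_inference_steps)])
# -> [0.875, 0.625, 0.375, 0.125]: the sampler moves from t near 1 (noise) toward t near 0
```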
src/diffusers/schedulers/__init__.py

```diff
@@ -19,6 +19,6 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
-from .scheduling_pndm import PNDMScheduler
 from .scheduling_grad_tts import GradTTSScheduler
+from .scheduling_pndm import PNDMScheduler
 from .scheduling_utils import SchedulerMixin
```