renzhc / diffusers_dcu · Commits

Commit ace07110, authored Jun 16, 2022 by patil-suraj
Commit message: style
Parent commit: 988369a0
Showing 4 changed files with 198 additions and 115 deletions (+198, -115). The diff is a code-style pass: string literals are normalized to double quotes, imports are reordered, and long lines are rewrapped.
  src/diffusers/models/unet_grad_tts.py         +3    -2
  src/diffusers/pipelines/grad_tts_utils.py     +170  -92
  src/diffusers/pipelines/pipeline_grad_tts.py  +24   -20
  src/diffusers/schedulers/__init__.py          +1    -1
src/diffusers/models/unet_grad_tts.py
...
@@ -145,8 +145,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
-            self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
-                                               torch.nn.Linear(spk_emb_dim * 4, n_feats))
+            self.spk_mlp = torch.nn.Sequential(
+                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
+            )
         self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
...
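Both MLPs above rely on the Mish activation defined earlier in this file. A minimal sketch of the speaker-conditioning path, assuming the usual Grad-TTS definition Mish(x) = x * tanh(softplus(x)); the dimensions below are illustrative, not taken from the commit:

    import torch

    class Mish(torch.nn.Module):
        # Assumed definition, matching the Grad-TTS reference: x * tanh(softplus(x)).
        def forward(self, x):
            return x * torch.tanh(torch.nn.functional.softplus(x))

    # Illustrative sizes only.
    n_spks, spk_emb_dim, n_feats = 4, 64, 80
    spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
    spk_mlp = torch.nn.Sequential(
        torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
    )

    speaker_ids = torch.tensor([0, 3])
    print(spk_mlp(spk_emb(speaker_ids)).shape)  # (batch,) ids -> torch.Size([2, 80])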
src/diffusers/pipelines/grad_tts_utils.py
 # tokenizer
-import re
 import os
+import re
 from shutil import copyfile

 import torch

 try:
     from transformers import PreTrainedTokenizer
 except:
...
@@ -25,17 +26,95 @@ except:
 valid_symbols = [
-    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
-    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
-    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
-    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
-    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
-    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
-    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
+    "AA",
+    "AA0",
+    "AA1",
+    "AA2",
+    "AE",
+    "AE0",
+    "AE1",
+    "AE2",
+    "AH",
+    "AH0",
+    "AH1",
+    "AH2",
+    "AO",
+    "AO0",
+    "AO1",
+    "AO2",
+    "AW",
+    "AW0",
+    "AW1",
+    "AW2",
+    "AY",
+    "AY0",
+    "AY1",
+    "AY2",
+    "B",
+    "CH",
+    "D",
+    "DH",
+    "EH",
+    "EH0",
+    "EH1",
+    "EH2",
+    "ER",
+    "ER0",
+    "ER1",
+    "ER2",
+    "EY",
+    "EY0",
+    "EY1",
+    "EY2",
+    "F",
+    "G",
+    "HH",
+    "IH",
+    "IH0",
+    "IH1",
+    "IH2",
+    "IY",
+    "IY0",
+    "IY1",
+    "IY2",
+    "JH",
+    "K",
+    "L",
+    "M",
+    "N",
+    "NG",
+    "OW",
+    "OW0",
+    "OW1",
+    "OW2",
+    "OY",
+    "OY0",
+    "OY1",
+    "OY2",
+    "P",
+    "R",
+    "S",
+    "SH",
+    "T",
+    "TH",
+    "UH",
+    "UH0",
+    "UH1",
+    "UH2",
+    "UW",
+    "UW0",
+    "UW1",
+    "UW2",
+    "V",
+    "W",
+    "Y",
+    "Z",
+    "ZH",
 ]

 _valid_symbol_set = set(valid_symbols)


 def intersperse(lst, item):
     # Adds blank symbol
     result = [item] * (len(lst) * 2 + 1)
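The intersperse helper in the context lines above inserts a blank symbol between every pair of tokens and at both ends; only its first statement is visible in the hunk. A sketch of the presumed full helper (the slice-assignment body is an assumption, matching the upstream Grad-TTS code):

    def intersperse(lst, item):
        # Allocate [item, item, ...] of length 2 * len(lst) + 1 ...
        result = [item] * (len(lst) * 2 + 1)
        # ... then drop the original tokens into the odd positions (assumed body).
        result[1::2] = lst
        return result

    print(intersperse([7, 8, 9], 0))  # [0, 7, 0, 8, 0, 9, 0]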
...
@@ -46,7 +125,7 @@ def intersperse(lst, item):
 class CMUDict:
     def __init__(self, file_or_path, keep_ambiguous=True):
         if isinstance(file_or_path, str):
-            with open(file_or_path, encoding='latin-1') as f:
+            with open(file_or_path, encoding="latin-1") as f:
                 entries = _parse_cmudict(f)
         else:
             entries = _parse_cmudict(file_or_path)
...
@@ -61,15 +140,15 @@ class CMUDict:
         return self._entries.get(word.upper())


-_alt_re = re.compile(r'\([0-9]+\)')
+_alt_re = re.compile(r"\([0-9]+\)")


 def _parse_cmudict(file):
     cmudict = {}
     for line in file:
-        if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
-            parts = line.split('  ')
-            word = re.sub(_alt_re, '', parts[0])
+        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
+            parts = line.split("  ")
+            word = re.sub(_alt_re, "", parts[0])
             pronunciation = _get_pronunciation(parts[1])
             if pronunciation:
                 if word in cmudict:
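_parse_cmudict expects the cmudict text format: a word and its ARPAbet pronunciation separated by two spaces, with alternate pronunciations carrying a (1)-style suffix that _alt_re strips. A standalone illustration with a made-up two-entry dictionary (the sample entries are hypothetical, not from the commit):

    import io
    import re

    _alt_re = re.compile(r"\([0-9]+\)")

    sample = io.StringIO("HOUSTON  HH Y UW1 S T AH0 N\nHOUSTON(1)  HH AW1 S T AH0 N\n")
    for line in sample:
        parts = line.strip().split("  ")
        word = re.sub(_alt_re, "", parts[0])  # "HOUSTON(1)" -> "HOUSTON"
        print(word, "->", parts[1])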
...
@@ -80,36 +159,38 @@ def _parse_cmudict(file):
 def _get_pronunciation(s):
-    parts = s.strip().split(' ')
+    parts = s.strip().split(" ")
     for part in parts:
         if part not in _valid_symbol_set:
             return None
-    return ' '.join(parts)
+    return " ".join(parts)


-_whitespace_re = re.compile(r'\s+')
+_whitespace_re = re.compile(r"\s+")

-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
+_abbreviations = [
+    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+    for x in [
+        ("mrs", "misess"),
+        ("mr", "mister"),
+        ("dr", "doctor"),
+        ("st", "saint"),
+        ("co", "company"),
+        ("jr", "junior"),
+        ("maj", "major"),
+        ("gen", "general"),
+        ("drs", "doctors"),
+        ("rev", "reverend"),
+        ("lt", "lieutenant"),
+        ("hon", "honorable"),
+        ("sgt", "sergeant"),
+        ("capt", "captain"),
+        ("esq", "esquire"),
+        ("ltd", "limited"),
+        ("col", "colonel"),
+        ("ft", "fort"),
+    ]
+]


 def expand_abbreviations(text):
...
@@ -127,7 +208,7 @@ def lowercase(text):
 def collapse_whitespace(text):
-    return re.sub(_whitespace_re, ' ', text)
+    return re.sub(_whitespace_re, " ", text)


 def convert_to_ascii(text):
...
@@ -156,46 +237,42 @@ def english_cleaners(text):
     return text


 _inflect = inflect.engine()
-_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
-_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
-_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
-_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
-_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
-_number_re = re.compile(r'[0-9]+')
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")


 def _remove_commas(m):
-    return m.group(1).replace(',', '')
+    return m.group(1).replace(",", "")


 def _expand_decimal_point(m):
-    return m.group(1).replace('.', ' point ')
+    return m.group(1).replace(".", " point ")


 def _expand_dollars(m):
     match = m.group(1)
-    parts = match.split('.')
+    parts = match.split(".")
     if len(parts) > 2:
-        return match + ' dollars'
+        return match + " dollars"
     dollars = int(parts[0]) if parts[0] else 0
     cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
     if dollars and cents:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
     elif dollars:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        return '%s %s' % (dollars, dollar_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        return "%s %s" % (dollars, dollar_unit)
     elif cents:
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s' % (cents, cent_unit)
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s" % (cents, cent_unit)
     else:
-        return 'zero dollars'
+        return "zero dollars"


 def _expand_ordinal(m):
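_expand_dollars receives the match produced by _dollars_re and spells the amount out, pluralizing units as needed. A self-contained check of the behavior, restating the function as it appears in the hunk above:

    import re

    _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")

    def _expand_dollars(m):
        match = m.group(1)
        parts = match.split(".")
        if len(parts) > 2:
            return match + " dollars"  # unexpected format, pass through
        dollars = int(parts[0]) if parts[0] else 0
        cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
        if dollars and cents:
            dollar_unit = "dollar" if dollars == 1 else "dollars"
            cent_unit = "cent" if cents == 1 else "cents"
            return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
        elif dollars:
            dollar_unit = "dollar" if dollars == 1 else "dollars"
            return "%s %s" % (dollars, dollar_unit)
        elif cents:
            cent_unit = "cent" if cents == 1 else "cents"
            return "%s %s" % (cents, cent_unit)
        else:
            return "zero dollars"

    print(re.sub(_dollars_re, _expand_dollars, "It costs $3.50, not $1."))
    # -> It costs 3 dollars, 50 cents, not 1 dollar.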
...
@@ -206,37 +283,37 @@ def _expand_number(m):
     num = int(m.group(0))
     if num > 1000 and num < 3000:
         if num == 2000:
-            return 'two thousand'
+            return "two thousand"
         elif num > 2000 and num < 2010:
-            return 'two thousand ' + _inflect.number_to_words(num % 100)
+            return "two thousand " + _inflect.number_to_words(num % 100)
         elif num % 100 == 0:
-            return _inflect.number_to_words(num // 100) + ' hundred'
+            return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword='', zero='oh',
-                                            group=2).replace(', ', ' ')
+            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
     else:
-        return _inflect.number_to_words(num, andword='')
+        return _inflect.number_to_words(num, andword="")


 def normalize_numbers(text):
     text = re.sub(_comma_number_re, _remove_commas, text)
-    text = re.sub(_pounds_re, r'\1 pounds', text)
+    text = re.sub(_pounds_re, r"\1 pounds", text)
     text = re.sub(_dollars_re, _expand_dollars, text)
     text = re.sub(_decimal_number_re, _expand_decimal_point, text)
     text = re.sub(_ordinal_re, _expand_ordinal, text)
     text = re.sub(_number_re, _expand_number, text)
     return text


 """ from https://github.com/keithito/tacotron """

-_pad = '_'
-_punctuation = '!\'(),.:;?'
-_special = '-'
-_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+_pad = "_"
+_punctuation = "!'(),.:;?"
+_special = "-"
+_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

 # Prepend "@" to ARPAbet symbols to ensure uniqueness:
-_arpabet = ['@' + s for s in valid_symbols]
+_arpabet = ["@" + s for s in valid_symbols]

 # Export all symbols:
 symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
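The exported symbols list fixes the tokenizer's id space: index 0 is the pad token, then the special and punctuation characters, then letters, then the @-prefixed ARPAbet entries. A small sketch of the layout (the ARPAbet tail is truncated here; the module uses all 84 valid_symbols):

    _pad = "_"
    _punctuation = "!'(),.:;?"
    _special = "-"
    _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    _arpabet = ["@" + s for s in ["AA", "AA0"]]  # truncated for the sketch

    symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
    print(symbols[:4])         # ['_', '-', '!', "'"]
    print(symbols.index("A"))  # 11: letters start after 1 pad + 1 special + 9 punctuation marks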
...
@@ -245,7 +322,7 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
 _id_to_symbol = {i: s for i, s in enumerate(symbols)}

-_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
+_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


 def get_arpabet(word, dictionary):
...
@@ -257,7 +334,7 @@ def get_arpabet(word, dictionary):
 def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
-    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

     The text can optionally have ARPAbet sequences enclosed in curly braces embedded
     in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
...
@@ -269,9 +346,9 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
     Returns:
         List of integers corresponding to the symbols in the text
-    '''
+    """
     sequence = []
-    space = _symbols_to_sequence(' ')
+    space = _symbols_to_sequence(" ")

     # Check for curly braces and treat their contents as ARPAbet:
     while len(text):
         m = _curly_re.match(text)
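text_to_sequence walks the input with _curly_re, peeling plain text (cleaned, then mapped symbol by symbol) off braced ARPAbet runs (sent through _arpabet_to_sequence). A sketch of how the regex carves the docstring's own example:

    import re

    _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")

    m = _curly_re.match("Turn left on {HH AW1 S S T AH0 N} Street.")
    print(m.group(1))  # 'Turn left on '       -> cleaned plain text
    print(m.group(2))  # 'HH AW1 S S T AH0 N'  -> treated as ARPAbet
    print(m.group(3))  # ' Street.'            -> remainder for the next loop pass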
...
@@ -300,16 +377,16 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
 def sequence_to_text(sequence):
-    '''Converts a sequence of IDs back to a string'''
+    """Converts a sequence of IDs back to a string"""
-    result = ''
+    result = ""
     for symbol_id in sequence:
         if symbol_id in _id_to_symbol:
             s = _id_to_symbol[symbol_id]
             # Enclose ARPAbet back in curly braces:
-            if len(s) > 1 and s[0] == '@':
-                s = '{%s}' % s[1:]
+            if len(s) > 1 and s[0] == "@":
+                s = "{%s}" % s[1:]
             result += s
-    return result.replace('}{', ' ')
+    return result.replace("}{", " ")


 def _clean_text(text, cleaner_names):
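On the way back, each ARPAbet id decodes to its own braced token, so consecutive phonemes first come out as {HH}{AW1}; the final replace("}{", " ") fuses them into a single braced run. A toy check of just that step:

    # Adjacent braced phonemes, as produced inside the loop...
    result = "{HH}{AW1}{S}"
    # ...are fused into one braced ARPAbet group by the final replace.
    print(result.replace("}{", " "))  # {HH AW1 S}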
...
@@ -323,17 +400,18 @@ def _symbols_to_sequence(symbols):
 def _arpabet_to_sequence(text):
-    return _symbols_to_sequence(['@' + s for s in text.split()])
+    return _symbols_to_sequence(["@" + s for s in text.split()])


 def _should_keep_symbol(s):
-    return s in _symbol_to_id and s != '_' and s != '~'
+    return s in _symbol_to_id and s != "_" and s != "~"


 VOCAB_FILES_NAMES = {
     "dict_file": "dict_file.txt",
 }


 class GradTTSTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
...
@@ -347,11 +425,11 @@ class GradTTSTokenizer(PreTrainedTokenizer):
         x_lengths = torch.LongTensor([x.shape[-1]])
         return x, x_lengths

     def save_vocabulary(self, save_directory: str, filename_prefix=None):
         dict_file = os.path.join(
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
         )
         copyfile(self.dict_file, dict_file)
         return (dict_file,)
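save_vocabulary builds the target filename from VOCAB_FILES_NAMES with an optional prefix, copies the dictionary file there, and returns the path as a one-element tuple, the shape PreTrainedTokenizer expects. A sketch of just the filename logic, with the join pulled into a hypothetical helper:

    import os

    VOCAB_FILES_NAMES = {
        "dict_file": "dict_file.txt",
    }

    def vocab_path(save_directory, filename_prefix=None):
        # Hypothetical helper mirroring save_vocabulary's filename construction.
        return os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
        )

    print(vocab_path("/tmp/tok"))             # /tmp/tok/dict_file.txt
    print(vocab_path("/tmp/tok", "gradtts"))  # /tmp/tok/gradtts-dict_file.txt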
src/diffusers/pipelines/pipeline_grad_tts.py
...
@@ -4,11 +4,11 @@ import math

 import torch
 from torch import nn

 import tqdm

-from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
+from diffusers import DiffusionPipeline

 from .grad_tts_utils import GradTTSTokenizer  # flake8: noqa
...
@@ -424,10 +424,14 @@ class GradTTS(DiffusionPipeline):
     def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
+        self.register_modules(
+            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
+        )

     @torch.no_grad()
-    def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None):
+    def __call__(
+        self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
+    ):
         if torch_device is None:
             torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
...
@@ -439,7 +443,7 @@ class GradTTS(DiffusionPipeline):
         x_lengths = x_lengths.to(torch_device)

         if speaker_id is not None:
             speaker_id = torch.LongTensor([speaker_id]).to(torch_device)

         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
         mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
...
@@ -465,7 +469,7 @@ class GradTTS(DiffusionPipeline):
         xt = z * y_mask
         h = 1.0 / num_inference_steps
         for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
             t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
             time = t.unsqueeze(-1).unsqueeze(-1)
             residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
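Inside the sampling loop the integer step index is remapped to a continuous diffusion time: with step size h = 1 / num_inference_steps, step i lands at t = 1 - (i + 0.5) * h, so the solver marches from just below 1 toward 0 at mid-step offsets. A standalone check of that schedule:

    import torch

    num_inference_steps = 4
    h = 1.0 / num_inference_steps
    for step in range(num_inference_steps):
        t = (1.0 - (step + 0.5) * h) * torch.ones(1)  # same mapping as the loop above
        print(step, round(t.item(), 3))
    # 0 0.875
    # 1 0.625
    # 2 0.375
    # 3 0.125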
...
src/diffusers/schedulers/__init__.py
...
@@ -19,6 +19,6 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
-from .scheduling_pndm import PNDMScheduler
 from .scheduling_grad_tts import GradTTSScheduler
+from .scheduling_pndm import PNDMScheduler
 from .scheduling_utils import SchedulerMixin