renzhc / diffusers_dcu · Commit 986cc9b2

add tokenizer

Authored Jun 16, 2022 by patil-suraj
Parent: 304d4d90

Showing 1 changed file with 341 additions and 0 deletions.

src/diffusers/pipelines/grad_tts_utils.py (new file, +341 −0)

# tokenizer
import re

import torch
from transformers import PreTrainedTokenizer

try:
    from unidecode import unidecode
except ImportError:
    print("unidecode is not installed")

try:
    import inflect
except ImportError:
    print("inflect is not installed")

valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2',
    'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2',
    'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
    'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2',
    'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2',
    'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)

def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

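# Example (a sketch of the behavior): a blank id turns [s1, s2] into
# [blank, s1, blank, s2, blank], i.e. length 2 * len(lst) + 1:
#   >>> intersperse([3, 7], 0)
#   [0, 3, 0, 7, 0]
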
class CMUDict:
    """Thin wrapper around the CMU pronouncing dictionary."""

    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding='latin-1') as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            # Drop words that have more than one pronunciation.
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        """Returns a list of ARPAbet pronunciations of the given word, or None."""
        return self._entries.get(word.upper())

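# Hypothetical usage sketch (output depends on the dictionary file supplied):
#   >>> cmu = CMUDict('cmu_dictionary')
#   >>> cmu.lookup('hello')
#   ['HH AH0 L OW1']   # one string of ARPAbet symbols per known pronunciation
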
_alt_re = re.compile(r'\([0-9]+\)')

def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
            parts = line.split('  ')
            word = re.sub(_alt_re, '', parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict

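# CMUdict lines separate the word from its pronunciation with two spaces, and
# alternative pronunciations carry a numeric suffix that _alt_re strips, so
# both collect under one key, e.g. (illustrative entries):
#   HELLO  HH AH0 L OW1
#   HELLO(1)  HH EH0 L OW1
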
def _get_pronunciation(s):
    parts = s.strip().split(' ')
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return ' '.join(parts)

_whitespace_re = re.compile(r'\s+')

_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]

def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text

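# Example (sketch): matching is case-insensitive and consumes the trailing dot:
#   >>> expand_abbreviations('Dr. Smith')
#   'doctor Smith'
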
def expand_numbers(text):
    return normalize_numbers(text)

def lowercase(text):
    return text.lower()

def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text)

def convert_to_ascii(text):
    return unidecode(text)

def basic_cleaners(text):
    # Lowercase and collapse whitespace, without transliteration.
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text

def transliteration_cleaners(text):
    # Transliterate to ASCII, then lowercase and collapse whitespace.
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text

def english_cleaners(text):
    # Full English pipeline: transliterate, lowercase, expand numbers and
    # abbreviations, then collapse whitespace.
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text

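# Example (sketch; requires unidecode and inflect). Numbers are expanded before
# abbreviations, and 'st.' always becomes 'saint', even when it means 'street':
#   >>> english_cleaners('Dr. Smith lives at 123 Main St.')
#   'doctor smith lives at one hundred twenty-three main saint'
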
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')

def _remove_commas(m):
    return m.group(1).replace(',', '')

def _expand_decimal_point(m):
    return m.group(1).replace('.', ' point ')

def _expand_dollars(m):
    match = m.group(1)
    parts = match.split('.')
    if len(parts) > 2:
        return match + ' dollars'
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        return '%s %s' % (dollars, dollar_unit)
    elif cents:
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s' % (cents, cent_unit)
    else:
        return 'zero dollars'

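# Examples (sketch): digits survive here and are verbalized later by
# _expand_number inside normalize_numbers:
#   '$1.50' -> '1 dollar, 50 cents'
#   '$5'    -> '5 dollars'
#   '$0.05' -> '5 cents'
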
def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))

def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return 'two thousand'
        elif num > 2000 and num < 2010:
            return 'two thousand ' + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + ' hundred'
        else:
            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
    else:
        return _inflect.number_to_words(num, andword='')

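# Examples (sketch): numbers strictly between 1000 and 3000 are read as years:
#   2000 -> 'two thousand'
#   2008 -> 'two thousand eight'
#   1500 -> 'fifteen hundred'
#   1984 -> 'nineteen eighty-four'
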
def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r'\1 pounds', text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text

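# Example (sketch): the substitutions compose, currency first, bare digits last:
#   >>> normalize_numbers('I owe $5.50')
#   'I owe five dollars, fifty cents'
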
""" from https://github.com/keithito/tacotron """
_pad
=
'_'
_punctuation
=
'!
\'
(),.:;? '
_special
=
'-'
_letters
=
'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet
=
[
'@'
+
s
for
s
in
valid_symbols
]
# Export all symbols:
symbols
=
[
_pad
]
+
list
(
_special
)
+
list
(
_punctuation
)
+
list
(
_letters
)
+
_arpabet
_symbol_to_id
=
{
s
:
i
for
i
,
s
in
enumerate
(
symbols
)}
_id_to_symbol
=
{
i
:
s
for
i
,
s
in
enumerate
(
symbols
)}
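# With this table: 1 pad + 1 special + 10 punctuation + 52 letters + 84 ARPAbet
# symbols, so len(symbols) should be 148. GradTTSTokenizer below uses
# len(symbols) itself as the out-of-vocabulary blank id passed to intersperse().
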
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')

def get_arpabet(word, dictionary):
    word_arpabet = dictionary.lookup(word)
    if word_arpabet is not None:
        return "{" + word_arpabet[0] + "}"
    else:
        return word

def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
        dictionary: arpabet class with arpabet dictionary

    Returns:
        List of integers corresponding to the symbols in the text
    '''
    sequence = []
    space = _symbols_to_sequence(' ')
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
                for i in range(len(clean_text)):
                    t = clean_text[i]
                    if t.startswith("{"):
                        sequence += _arpabet_to_sequence(t[1:-1])
                    else:
                        sequence += _symbols_to_sequence(t)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # remove trailing space
    if dictionary is not None:
        sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
    return sequence

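# Example (sketch; ids follow the symbol table above, where 'h' = 45, 'i' = 46
# and '!' = 2):
#   >>> text_to_sequence('Hi!', cleaner_names=[basic_cleaners])
#   [45, 46, 2]
# Braced spans such as '{HH AH0 L OW1}' skip cleaning and map directly to
# '@'-prefixed ARPAbet ids.
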
def sequence_to_text(sequence):
    '''Converts a sequence of IDs back to a string'''
    result = ''
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == '@':
                s = '{%s}' % s[1:]
            result += s
    return result.replace('}{', ' ')

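# Example (sketch): the inverse of text_to_sequence up to cleaning:
#   >>> sequence_to_text(text_to_sequence('Hi!', cleaner_names=[basic_cleaners]))
#   'hi!'
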
def _clean_text(text, cleaner_names):
    for cleaner in cleaner_names:
        text = cleaner(text)
    return text

def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]

def _arpabet_to_sequence(text):
    return _symbols_to_sequence(['@' + s for s in text.split()])

def _should_keep_symbol(s):
    return s in _symbol_to_id and s != '_' and s != '~'

VOCAB_FILES_NAMES = {
    "dict_file": "merges.txt",
}

class GradTTSTokenizer(PreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, dict_file, **kwargs):
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)

    def __call__(self, text):
        # Phonemize, intersperse the blank id (len(symbols)) between tokens,
        # and add a batch dimension.
        x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths
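
# Minimal usage sketch (assumptions: a CMUdict file saved locally as merges.txt,
# plus unidecode and inflect installed):
if __name__ == "__main__":
    tokenizer = GradTTSTokenizer(dict_file="merges.txt")
    x, x_lengths = tokenizer("Hello world!")
    print(x.shape)    # (1, L): token ids with blanks interspersed
    print(x_lengths)  # tensor([L])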