Commit 7cabc0cd (unverified), authored Jun 16, 2022 by Suraj Patil, committed by GitHub on Jun 16, 2022

Add GradTTS

Parents: bed32182, c2e48b23

Showing 7 changed files with 559 additions and 3 deletions (+559 -3)
src/diffusers/__init__.py                        +2   -2
src/diffusers/models/unet_grad_tts.py            +7   -1
src/diffusers/pipelines/__init__.py              +1   -0
src/diffusers/pipelines/grad_tts_utils.py        +435 -0
src/diffusers/pipelines/pipeline_grad_tts.py     +64  -0
src/diffusers/schedulers/__init__.py             +1   -0
src/diffusers/schedulers/scheduling_grad_tts.py  +49  -0
src/diffusers/__init__.py

@@ -10,6 +10,6 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_grad_tts import UNetGradTTSModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
+from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
-from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
+from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
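With both exports in place, the new classes become importable from the package root. A minimal sanity check (a sketch, assuming this revision of diffusers is installed):

import diffusers  # noqa: F401

from diffusers import GradTTS, GradTTSScheduler

scheduler = GradTTSScheduler()  # defaults: 1000 timesteps, linear betas
print(len(scheduler))           # 1000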
src/diffusers/models/unet_grad_tts.py

@@ -4,7 +4,7 @@ import torch
 try:
-    from einops import rearrange, repeat
+    from einops import rearrange
 except:
     print("Einops is not installed")
     pass

@@ -144,9 +144,11 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.pe_scale = pe_scale

         if n_spks > 1:
+            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
             self.spk_mlp = torch.nn.Sequential(
                 torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
             )
+
         self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

@@ -189,6 +191,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.final_conv = torch.nn.Conv2d(dim, 1, 1)

     def forward(self, x, mask, mu, t, spk=None):
+        if self.n_spks > 1:
+            # Get speaker embedding
+            spk = self.spk_emb(spk)
+
         if not isinstance(spk, type(None)):
             s = self.spk_mlp(spk)
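The forward hunk moves speaker lookup inside the model: an integer speaker id is embedded and pushed through spk_mlp to produce a per-speaker conditioning vector. A standalone sketch of that pattern (the sizes below are illustrative, not the commit's actual defaults):

import torch


class Mish(torch.nn.Module):
    # Same activation the Grad-TTS U-Net uses
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


n_spks, spk_emb_dim, n_feats = 247, 64, 80  # illustrative sizes

spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
spk_mlp = torch.nn.Sequential(
    torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
)

speaker_id = torch.LongTensor([15])
s = spk_mlp(spk_emb(speaker_id))  # (1, n_feats) conditioning vector
print(s.shape)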
src/diffusers/pipelines/__init__.py

 from .pipeline_bddm import BDDM
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
+from .pipeline_grad_tts import GradTTS

 try:
src/diffusers/pipelines/grad_tts_utils.py (new file, 0 → 100644)
# tokenizer
import os
import re
from shutil import copyfile

import torch


try:
    from transformers import PreTrainedTokenizer
except:
    print("transformers is not installed")

try:
    from unidecode import unidecode
except:
    print("unidecode is not installed")
    pass

try:
    import inflect
except:
    print("inflect is not installed")
    pass


valid_symbols = [
    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
    "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
    "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2",
    "EY", "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2",
    "IY", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG",
    "OW", "OW0", "OW1", "OW2", "OY", "OY0", "OY1", "OY2",
    "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2",
    "UW", "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH",
]

_valid_symbol_set = set(valid_symbols)


def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result
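intersperse weaves the blank token between (and around) the symbol IDs, which Grad-TTS uses as a separator. A quick check of its behavior:

print(intersperse([5, 6, 7], 0))  # -> [0, 5, 0, 6, 0, 7, 0]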
class CMUDict:
    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        return self._entries.get(word.upper())


_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split("  ")  # CMUdict separates word and pronunciation with two spaces
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict


def _get_pronunciation(s):
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)
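The parser expects the classic CMUdict layout: word, two spaces, then the space-separated ARPAbet pronunciation. A quick in-memory check (the entry is a real CMUdict line):

import io

cmu = CMUDict(io.StringIO("HELLO  HH AH0 L OW1\n"))
print(cmu.lookup("hello"))  # -> ['HH AH0 L OW1']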
_whitespace_re = re.compile(r"\s+")

_abbreviations = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def expand_numbers(text):
    return normalize_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return "two thousand"
        elif num > 2000 and num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        else:
            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    else:
        return _inflect.number_to_words(num, andword="")


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
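Taken together, english_cleaners lowercases the text, then expands currency, ordinals, plain numbers, and abbreviations. For instance, tracing the functions above (requires unidecode and inflect):

print(english_cleaners("Dr. Smith paid $3.50."))
# -> "doctor smith paid three dollars, fifty cents."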
""" from https://github.com/keithito/tacotron """
_pad
=
"_"
_punctuation
=
"!'(),.:;? "
_special
=
"-"
_letters
=
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet
=
[
"@"
+
s
for
s
in
valid_symbols
]
# Export all symbols:
symbols
=
[
_pad
]
+
list
(
_special
)
+
list
(
_punctuation
)
+
list
(
_letters
)
+
_arpabet
_symbol_to_id
=
{
s
:
i
for
i
,
s
in
enumerate
(
symbols
)}
_id_to_symbol
=
{
i
:
s
for
i
,
s
in
enumerate
(
symbols
)}
_curly_re
=
re
.
compile
(
r
"(.*?)\{(.+?)\}(.*)"
)
def get_arpabet(word, dictionary):
    word_arpabet = dictionary.lookup(word)
    if word_arpabet is not None:
        return "{" + word_arpabet[0] + "}"
    else:
        return word


def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
        dictionary: arpabet class with arpabet dictionary

    Returns:
        List of integers corresponding to the symbols in the text
    """
    sequence = []
    space = _symbols_to_sequence(" ")
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
                for i in range(len(clean_text)):
                    t = clean_text[i]
                    if t.startswith("{"):
                        sequence += _arpabet_to_sequence(t[1:-1])
                    else:
                        sequence += _symbols_to_sequence(t)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # remove trailing space
    if dictionary is not None:
        sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
    return sequence
def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    result = ""
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == "@":
                s = "{%s}" % s[1:]
            result += s
    return result.replace("}{", " ")


def _clean_text(text, cleaner_names):
    for cleaner in cleaner_names:
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(["@" + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s != "_" and s != "~"
VOCAB_FILES_NAMES = {
    "dict_file": "dict_file.txt",
}


class GradTTSTokenizer(PreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, dict_file, **kwargs):
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)
        self.dict_file = dict_file

    def __call__(self, text):
        x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        dict_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"],
        )
        copyfile(self.dict_file, dict_file)
        return (dict_file,)
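The tokenizer wraps the text pipeline above behind a PreTrainedTokenizer interface: it converts text to symbol IDs via the CMU dictionary and intersperses the blank token (id len(symbols)). A usage sketch, with a hypothetical dictionary path (in practice the dict file ships with a pretrained checkpoint):

tokenizer = GradTTSTokenizer(dict_file="path/to/dict_file.txt")  # hypothetical path
x, x_lengths = tokenizer("Hello world")
print(x.shape, x_lengths)  # token IDs of shape (1, seq_len), plus the length tensor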
src/diffusers/pipelines/pipeline_grad_tts.py

@@ -5,9 +5,13 @@ import math
 import torch
 from torch import nn
+import tqdm

+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
+from .grad_tts_utils import GradTTSTokenizer

+# flake8: noqa

 def sequence_mask(length, max_length=None):
     if max_length is None:

@@ -414,3 +418,63 @@ class TextEncoder(ModelMixin, ConfigMixin):
         logw = self.proj_w(x_dp, x_mask)

         return mu, logw, x_mask
+
+
+class GradTTS(DiffusionPipeline):
+    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
+        super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
+        self.register_modules(
+            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
+        )
+
+    @torch.no_grad()
+    def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None):
+        if torch_device is None:
+            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+        self.unet.to(torch_device)
+        self.text_encoder.to(torch_device)
+
+        x, x_lengths = self.tokenizer(text)
+        x = x.to(torch_device)
+        x_lengths = x_lengths.to(torch_device)
+
+        if speaker_id is not None:
+            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)
+
+        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
+        mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
+
+        w = torch.exp(logw) * x_mask
+        w_ceil = torch.ceil(w) * length_scale
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_max_length = int(y_lengths.max())
+        y_max_length_ = fix_len_compatibility(y_max_length)
+
+        # Using obtained durations `w` construct alignment map `attn`
+        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
+        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
+        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
+
+        # Align encoded text and get mu_y
+        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+        mu_y = mu_y.transpose(1, 2)
+
+        # Sample latent representation from terminal distribution N(mu_y, I)
+        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
+
+        xt = z * y_mask
+        h = 1.0 / num_inference_steps
+        for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
+            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
+            time = t.unsqueeze(-1).unsqueeze(-1)
+            residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
+            xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
+            xt = xt * y_mask
+
+        return xt[:, :, :y_max_length]
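End to end, the pipeline tokenizes the text, predicts per-token durations, builds the alignment, samples from N(mu_y, I), and runs the reverse diffusion loop. A hedged usage sketch (the checkpoint id is hypothetical; this commit ships no weights, and the output is a mel-spectrogram, so a separate vocoder is needed for audio):

from diffusers import GradTTS

pipeline = GradTTS.from_pretrained("path/to/grad-tts-checkpoint")  # hypothetical checkpoint
mel = pipeline("Hello world", num_inference_steps=50)
print(mel.shape)  # (batch, n_feats, frames)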
src/diffusers/schedulers/__init__.py

@@ -19,5 +19,6 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
+from .scheduling_grad_tts import GradTTSScheduler
 from .scheduling_pndm import PNDMScheduler
 from .scheduling_utils import SchedulerMixin
src/diffusers/schedulers/scheduling_grad_tts.py (new file, 0 → 100644)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


class GradTTSScheduler(SchedulerMixin, ConfigMixin):
    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        tensor_format="np",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
        )
        self.timesteps = int(timesteps)
        self.set_format(tensor_format=tensor_format)

    def sample_noise(self, timestep):
        noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
        return noise

    def step(self, xt, residual, mu, h, timestep):
        noise_t = self.sample_noise(timestep)
        dxt = 0.5 * (mu - xt - residual)
        dxt = dxt * noise_t * h
        xt = xt - dxt
        return xt

    def __len__(self):
        return self.timesteps
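Reading step against sample_noise: with a linear schedule beta_t = beta_start + (beta_end - beta_start) * t, each call performs what looks like an Euler step of the Grad-TTS reverse process, x <- x - 0.5 * beta_t * h * (mu - x - residual), where residual is the U-Net output and h = 1 / num_inference_steps. A self-contained numeric sketch of the same arithmetic (shapes are illustrative):

import torch

beta_start, beta_end = 0.0001, 0.02  # the scheduler's defaults


def euler_step(xt, residual, mu, h, t):
    # Same update as GradTTSScheduler.step, written out standalone
    beta_t = beta_start + (beta_end - beta_start) * t
    dxt = 0.5 * (mu - xt - residual) * beta_t * h
    return xt - dxt


xt = torch.randn(1, 80, 172)   # hypothetical mel-shaped latent
mu = torch.zeros_like(xt)      # encoder mean (illustrative)
residual = torch.zeros_like(xt)  # U-Net output (illustrative)
xt = euler_step(xt, residual, mu, h=1 / 50, t=0.99)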