renzhc / diffusers_dcu · Commit 7cabc0cd

Unverified commit 7cabc0cd, authored Jun 16, 2022 by Suraj Patil, committed by GitHub on Jun 16, 2022.

Add GradTTS

Parents: bed32182, c2e48b23
Showing 7 changed files with 559 additions and 3 deletions (+559, -3):

- src/diffusers/__init__.py (+2, -2)
- src/diffusers/models/unet_grad_tts.py (+7, -1)
- src/diffusers/pipelines/__init__.py (+1, -0)
- src/diffusers/pipelines/grad_tts_utils.py (+435, -0)
- src/diffusers/pipelines/pipeline_grad_tts.py (+64, -0)
- src/diffusers/schedulers/__init__.py (+1, -0)
- src/diffusers/schedulers/scheduling_grad_tts.py (+49, -0)
src/diffusers/__init__.py (+2, -2)

```diff
@@ -10,6 +10,6 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_grad_tts import UNetGradTTSModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
-from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
+from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
+from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
```
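With both names re-exported from the package root, the pipeline and its scheduler become importable directly; a minimal sketch, assuming this commit is installed as `diffusers`:

```python
# The two new public names added by this commit.
from diffusers import GradTTS, GradTTSScheduler

scheduler = GradTTSScheduler(tensor_format="pt")
```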
src/diffusers/models/unet_grad_tts.py (+7, -1)

```diff
@@ -4,7 +4,7 @@ import torch
 try:
-    from einops import rearrange, repeat
+    from einops import rearrange
 except:
     print("Einops is not installed")
     pass
```

```diff
@@ -144,9 +144,11 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.pe_scale = pe_scale

         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
             self.spk_mlp = torch.nn.Sequential(
                 torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
             )
         self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
```

```diff
@@ -189,6 +191,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.final_conv = torch.nn.Conv2d(dim, 1, 1)

     def forward(self, x, mask, mu, t, spk=None):
+        if self.n_spks > 1:
+            # Get speaker embedding
+            spk = self.spk_emb(spk)
+
         if not isinstance(spk, type(None)):
             s = self.spk_mlp(spk)
```
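The new lines in `forward` turn a raw integer speaker id into an embedding vector before the existing `spk_mlp` branch runs. A standalone sketch of that lookup (the sizes are illustrative assumptions, not values from the commit):

```python
import torch

n_spks, spk_emb_dim = 247, 64          # illustrative; set by the model config
spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)

speaker_id = torch.LongTensor([15])    # what the pipeline passes in
spk = spk_emb(speaker_id)              # (1, 64) vector, then fed through spk_mlp
print(spk.shape)                       # torch.Size([1, 64])
```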
src/diffusers/pipelines/__init__.py (+1, -0)

```diff
 from .pipeline_bddm import BDDM
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
+from .pipeline_grad_tts import GradTTS

 try:
```
src/diffusers/pipelines/grad_tts_utils.py (new file, +435, -0)

```python
# tokenizer

import os
import re
from shutil import copyfile

import torch

try:
    from transformers import PreTrainedTokenizer
except:
    print("transformers is not installed")

try:
    from unidecode import unidecode
except:
    print("unidecode is not installed")
    pass

try:
    import inflect
except:
    print("inflect is not installed")
    pass


valid_symbols = [
    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
    "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
    "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2",
    "EY", "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2",
    "IY", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG",
    "OW", "OW0", "OW1", "OW2", "OY", "OY0", "OY1", "OY2", "P", "R", "S", "SH",
    "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", "UW0", "UW1", "UW2",
    "V", "W", "Y", "Z", "ZH",
]

_valid_symbol_set = set(valid_symbols)


def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result
```
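`intersperse` places the blank id between consecutive tokens and at both ends, so the output length is always `2 * len(lst) + 1`:

```python
# Blank id 148 is illustrative; the tokenizer below uses len(symbols) as the blank.
print(intersperse([5, 9, 2], 148))
# [148, 5, 148, 9, 148, 2, 148]
```

The new file continues with the CMU pronouncing dictionary loader: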
```python
class CMUDict:
    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        return self._entries.get(word.upper())


_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split("  ")
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict


def _get_pronunciation(s):
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)
```
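`CMUDict` expects the cmudict-0.7b text format, where the head word and its ARPAbet pronunciation are separated by two spaces and alternate pronunciations carry a `(1)`, `(2)`, ... suffix. A sketch with an in-memory file (the entries are illustrative):

```python
import io

fake_dict = io.StringIO(
    "HOUSTON  HH Y UW1 S T AH0 N\n"
    "READ  R EH1 D\n"
    "READ(1)  R IY1 D\n"
)
cmu = CMUDict(fake_dict)
print(cmu.lookup("read"))  # ['R EH1 D', 'R IY1 D'] (ambiguous entries kept by default)
```

Next come the text cleaners, taken from the keithito/tacotron preprocessing: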
```python
_whitespace_re = re.compile(r"\s+")

_abbreviations = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def expand_numbers(text):
    return normalize_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text
```
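`english_cleaners` chains ASCII transliteration, lowercasing, number expansion, and abbreviation expansion; for example (assuming `unidecode` and `inflect` are installed):

```python
print(english_cleaners("Dr. Smith lives at 221 Baker St."))
# "doctor smith lives at two hundred twenty-one baker saint."
```

The number-normalization helpers follow: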
```python
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return "two thousand"
        elif num > 2000 and num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        else:
            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    else:
        return _inflect.number_to_words(num, andword="")


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
```
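The substitution order in `normalize_numbers` matters: comma grouping and currency are handled before bare digits, so `$16.50` becomes spoken dollars and cents rather than a decimal:

```python
print(normalize_numbers("I paid $16.50 for 2,000 tickets on the 3rd."))
# "I paid sixteen dollars, fifty cents for two thousand tickets on the third."
```

The symbol table and sequence converters come next: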
""" from https://github.com/keithito/tacotron """
_pad
=
"_"
_punctuation
=
"!'(),.:;? "
_special
=
"-"
_letters
=
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet
=
[
"@"
+
s
for
s
in
valid_symbols
]
# Export all symbols:
symbols
=
[
_pad
]
+
list
(
_special
)
+
list
(
_punctuation
)
+
list
(
_letters
)
+
_arpabet
_symbol_to_id
=
{
s
:
i
for
i
,
s
in
enumerate
(
symbols
)}
_id_to_symbol
=
{
i
:
s
for
i
,
s
in
enumerate
(
symbols
)}
_curly_re
=
re
.
compile
(
r
"(.*?)\{(.+?)\}(.*)"
)
def
get_arpabet
(
word
,
dictionary
):
word_arpabet
=
dictionary
.
lookup
(
word
)
if
word_arpabet
is
not
None
:
return
"{"
+
word_arpabet
[
0
]
+
"}"
else
:
return
word
def
text_to_sequence
(
text
,
cleaner_names
=
[
english_cleaners
],
dictionary
=
None
):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
dictionary: arpabet class with arpabet dictionary
Returns:
List of integers corresponding to the symbols in the text
"""
sequence
=
[]
space
=
_symbols_to_sequence
(
" "
)
# Check for curly braces and treat their contents as ARPAbet:
while
len
(
text
):
m
=
_curly_re
.
match
(
text
)
if
not
m
:
clean_text
=
_clean_text
(
text
,
cleaner_names
)
if
dictionary
is
not
None
:
clean_text
=
[
get_arpabet
(
w
,
dictionary
)
for
w
in
clean_text
.
split
(
" "
)]
for
i
in
range
(
len
(
clean_text
)):
t
=
clean_text
[
i
]
if
t
.
startswith
(
"{"
):
sequence
+=
_arpabet_to_sequence
(
t
[
1
:
-
1
])
else
:
sequence
+=
_symbols_to_sequence
(
t
)
sequence
+=
space
else
:
sequence
+=
_symbols_to_sequence
(
clean_text
)
break
sequence
+=
_symbols_to_sequence
(
_clean_text
(
m
.
group
(
1
),
cleaner_names
))
sequence
+=
_arpabet_to_sequence
(
m
.
group
(
2
))
text
=
m
.
group
(
3
)
# remove trailing space
if
dictionary
is
not
None
:
sequence
=
sequence
[:
-
1
]
if
sequence
[
-
1
]
==
space
[
0
]
else
sequence
return
sequence
def
sequence_to_text
(
sequence
):
"""Converts a sequence of IDs back to a string"""
result
=
""
for
symbol_id
in
sequence
:
if
symbol_id
in
_id_to_symbol
:
s
=
_id_to_symbol
[
symbol_id
]
# Enclose ARPAbet back in curly braces:
if
len
(
s
)
>
1
and
s
[
0
]
==
"@"
:
s
=
"{%s}"
%
s
[
1
:]
result
+=
s
return
result
.
replace
(
"}{"
,
" "
)
def
_clean_text
(
text
,
cleaner_names
):
for
cleaner
in
cleaner_names
:
text
=
cleaner
(
text
)
return
text
def
_symbols_to_sequence
(
symbols
):
return
[
_symbol_to_id
[
s
]
for
s
in
symbols
if
_should_keep_symbol
(
s
)]
def
_arpabet_to_sequence
(
text
):
return
_symbols_to_sequence
([
"@"
+
s
for
s
in
text
.
split
()])
def
_should_keep_symbol
(
s
):
return
s
in
_symbol_to_id
and
s
!=
"_"
and
s
!=
"~"
```python
VOCAB_FILES_NAMES = {
    "dict_file": "dict_file.txt",
}


class GradTTSTokenizer(PreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, dict_file, **kwargs):
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)
        self.dict_file = dict_file

    def __call__(self, text):
        x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        dict_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
        )

        copyfile(self.dict_file, dict_file)

        return (dict_file,)
```
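`GradTTSTokenizer` ties the pieces together: CMUDict lookup inside `text_to_sequence`, then `intersperse` with `len(symbols)` as the blank id, returning a batched id tensor plus its length. A usage sketch, assuming a local cmudict-format file:

```python
tokenizer = GradTTSTokenizer("cmu_dictionary.txt")  # hypothetical local path
x, x_lengths = tokenizer("Hello world")
print(x.shape)    # torch.Size([1, 2 * n_tokens + 1])
print(x_lengths)  # tensor([x.shape[-1]])
```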
src/diffusers/pipelines/pipeline_grad_tts.py (+64, -0)

```diff
@@ -5,9 +5,13 @@ import math
 import torch
 from torch import nn

 import tqdm

 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
+from .grad_tts_utils import GradTTSTokenizer


 # flake8: noqa


 def sequence_mask(length, max_length=None):
     if max_length is None:
```

The second hunk appends the new `GradTTS` pipeline after the `TextEncoder`:

```diff
@@ -414,3 +418,63 @@ class TextEncoder(ModelMixin, ConfigMixin):
         logw = self.proj_w(x_dp, x_mask)

         return mu, logw, x_mask
```

```python
class GradTTS(DiffusionPipeline):
    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(
            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
        )

    @torch.no_grad()
    def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None):
        if torch_device is None:
            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.unet.to(torch_device)
        self.text_encoder.to(torch_device)

        x, x_lengths = self.tokenizer(text)
        x = x.to(torch_device)
        x_lengths = x_lengths.to(torch_device)

        if speaker_id is not None:
            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.text_encoder(x, x_lengths)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = int(y_lengths.max())
        y_max_length_ = fix_len_compatibility(y_max_length)

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Sample latent representation from terminal distribution N(mu_y, I)
        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
        xt = z * y_mask

        h = 1.0 / num_inference_steps
        for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)

            residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
            xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
            xt = xt * y_mask

        return xt[:, :, :y_max_length]
```
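End to end, `__call__` tokenizes the text, predicts per-token durations, expands the encoder outputs into the alignment `attn`, samples from the terminal distribution N(mu_y, I), and integrates the reverse ODE for `num_inference_steps` Euler steps. A usage sketch; the checkpoint id is hypothetical, since the commit references no hosted weights:

```python
from diffusers import GradTTS

pipe = GradTTS.from_pretrained("fusing/grad-tts")  # hypothetical checkpoint id
mel = pipe("Hello world, this is Grad-TTS.", num_inference_steps=50, speaker_id=15)
print(mel.shape)  # (1, n_feats, T) mel-spectrogram; a separate vocoder renders audio
```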
src/diffusers/schedulers/__init__.py (+1, -0)

```diff
@@ -19,5 +19,6 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
+from .scheduling_grad_tts import GradTTSScheduler
 from .scheduling_pndm import PNDMScheduler
 from .scheduling_utils import SchedulerMixin
```
src/diffusers/schedulers/scheduling_grad_tts.py (new file, +49, -0)

```python
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


class GradTTSScheduler(SchedulerMixin, ConfigMixin):
    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        tensor_format="np",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
        )
        self.timesteps = int(timesteps)

        self.set_format(tensor_format=tensor_format)

    def sample_noise(self, timestep):
        noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
        return noise

    def step(self, xt, residual, mu, h, timestep):
        noise_t = self.sample_noise(timestep)
        dxt = 0.5 * (mu - xt - residual)
        dxt = dxt * noise_t * h
        xt = xt - dxt
        return xt

    def __len__(self):
        return self.timesteps
```
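The schedule is linear, β(t) = beta_start + (beta_end − beta_start)·t, and `step` applies one explicit Euler update, xt ← xt − 0.5·β(t)·h·(mu − xt − residual), matching the reverse ODE that the pipeline's sampling loop integrates from t ≈ 1 down to t ≈ 0. A toy sketch of a single step (shapes and values are illustrative):

```python
import torch

sched = GradTTSScheduler(tensor_format="pt")
xt = torch.randn(1, 80, 172)        # current noisy mel state
mu = torch.zeros(1, 80, 172)        # encoder mean mu_y (terminal distribution is N(mu_y, I))
residual = torch.zeros(1, 80, 172)  # stand-in for the unet's predicted score term

# One of num_inference_steps Euler steps, at mid-bin time t = 0.99 with h = 1/50.
xt = sched.step(xt, residual, mu, h=1.0 / 50, timestep=0.99)
```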