Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
f941fc99
Commit
f941fc99
authored
Jun 22, 2022
by
Patrick von Platen
Browse files
refactor tts sampler a bit
parent
4fbf8c81
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
17 deletions
+30
-17
src/diffusers/pipelines/pipeline_glide.py
src/diffusers/pipelines/pipeline_glide.py
+1
-0
src/diffusers/pipelines/pipeline_grad_tts.py
src/diffusers/pipelines/pipeline_grad_tts.py
+4
-2
src/diffusers/schedulers/scheduling_grad_tts.py
src/diffusers/schedulers/scheduling_grad_tts.py
+22
-15
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+3
-0
No files found.
src/diffusers/pipelines/pipeline_glide.py
View file @
f941fc99
...
...
@@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
# END OF THE CLIP MODEL COPY-PASTE
#####################
def
_extract_into_tensor
(
arr
,
timesteps
,
broadcast_shape
):
"""
Extract values from a 1-D numpy array for a batch of indices.
...
...
src/diffusers/pipelines/pipeline_grad_tts.py
View file @
f941fc99
...
...
@@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline):
xt
=
z
*
y_mask
h
=
1.0
/
num_inference_steps
# (Patrick: TODO)
for
t
in
tqdm
.
tqdm
(
range
(
num_inference_steps
),
total
=
num_inference_steps
):
t_new
=
num_inference_steps
-
t
-
1
t
=
(
1.0
-
(
t
+
0.5
)
*
h
)
*
torch
.
ones
(
z
.
shape
[
0
],
dtype
=
z
.
dtype
,
device
=
z
.
device
)
time
=
t
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
residual
=
self
.
unet
(
xt
,
t
,
mu_y
,
y_mask
,
speaker_id
)
xt
=
self
.
noise_scheduler
.
step
(
xt
,
residual
,
mu_y
,
h
,
time
)
scheduler_residual
=
residual
-
mu_y
+
xt
xt
=
self
.
noise_scheduler
.
step
(
scheduler_residual
,
xt
,
t_new
,
num_inference_steps
)
xt
=
xt
*
y_mask
return
xt
[:,
:,
:
y_max_length
]
src/diffusers/schedulers/scheduling_grad_tts.py
View file @
f941fc99
...
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin
class
GradTTSScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
self
,
timesteps
=
1000
,
beta_start
=
0.0001
,
beta_end
=
0.02
,
beta_start
=
0.05
,
beta_end
=
20
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register_to_config
(
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
)
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
betas
=
None
def
get_timesteps
(
self
,
num_inference_steps
):
return
np
.
array
([(
t
+
0.5
)
/
num_inference_steps
for
t
in
range
(
num_inference_steps
)])
def
set_betas
(
self
,
num_inference_steps
):
timesteps
=
self
.
get_timesteps
(
num_inference_steps
)
self
.
betas
=
np
.
array
([
self
.
beta_start
+
(
self
.
beta_end
-
self
.
beta_start
)
*
t
for
t
in
timesteps
])
def
step
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
# This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
if
self
.
betas
is
None
:
self
.
set_betas
(
num_inference_steps
)
def
sample_noise
(
self
,
timestep
):
noise
=
self
.
beta_start
+
(
self
.
beta_end
-
self
.
beta_start
)
*
timestep
return
noise
beta_t
=
self
.
betas
[
t
]
beta_t_deriv
=
beta_t
/
num_inference_steps
def
step
(
self
,
xt
,
residual
,
mu
,
h
,
timestep
):
noise_t
=
self
.
sample_noise
(
timestep
)
dxt
=
0.5
*
(
mu
-
xt
-
residual
)
dxt
=
dxt
*
noise_t
*
h
xt
=
xt
-
dxt
return
xt
sample_deriv
=
residual
*
beta_t_deriv
/
2
def
__len__
(
self
):
return
len
(
self
.
config
.
timesteps
)
sample
=
sample
+
sample_deriv
return
sample
tests/test_modeling_utils.py
View file @
f941fc99
...
...
@@ -31,6 +31,7 @@ from diffusers import (
GlideSuperResUNetModel
,
GlideTextToImageUNetModel
,
GradTTSPipeline
,
GradTTSScheduler
,
LatentDiffusionPipeline
,
PNDMPipeline
,
PNDMScheduler
,
...
...
@@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_grad_tts
(
self
):
model_id
=
"fusing/grad-tts-libri-tts"
grad_tts
=
GradTTSPipeline
.
from_pretrained
(
model_id
)
noise_scheduler
=
GradTTSScheduler
()
grad_tts
.
noise_scheduler
=
noise_scheduler
text
=
"Hello world, I missed you so much."
generator
=
torch
.
manual_seed
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment