Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
f941fc99
Commit
f941fc99
authored
Jun 22, 2022
by
Patrick von Platen
Browse files
refactor tts sampler a bit
parent
4fbf8c81
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
17 deletions
+30
-17
src/diffusers/pipelines/pipeline_glide.py
src/diffusers/pipelines/pipeline_glide.py
+1
-0
src/diffusers/pipelines/pipeline_grad_tts.py
src/diffusers/pipelines/pipeline_grad_tts.py
+4
-2
src/diffusers/schedulers/scheduling_grad_tts.py
src/diffusers/schedulers/scheduling_grad_tts.py
+22
-15
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+3
-0
No files found.
src/diffusers/pipelines/pipeline_glide.py
View file @
f941fc99
...
@@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
...
@@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
# END OF THE CLIP MODEL COPY-PASTE
# END OF THE CLIP MODEL COPY-PASTE
#####################
#####################
def
_extract_into_tensor
(
arr
,
timesteps
,
broadcast_shape
):
def
_extract_into_tensor
(
arr
,
timesteps
,
broadcast_shape
):
"""
"""
Extract values from a 1-D numpy array for a batch of indices.
Extract values from a 1-D numpy array for a batch of indices.
...
...
src/diffusers/pipelines/pipeline_grad_tts.py
View file @
f941fc99
...
@@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline):
...
@@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline):
xt
=
z
*
y_mask
xt
=
z
*
y_mask
h
=
1.0
/
num_inference_steps
h
=
1.0
/
num_inference_steps
# (Patrick: TODO)
for
t
in
tqdm
.
tqdm
(
range
(
num_inference_steps
),
total
=
num_inference_steps
):
for
t
in
tqdm
.
tqdm
(
range
(
num_inference_steps
),
total
=
num_inference_steps
):
t_new
=
num_inference_steps
-
t
-
1
t
=
(
1.0
-
(
t
+
0.5
)
*
h
)
*
torch
.
ones
(
z
.
shape
[
0
],
dtype
=
z
.
dtype
,
device
=
z
.
device
)
t
=
(
1.0
-
(
t
+
0.5
)
*
h
)
*
torch
.
ones
(
z
.
shape
[
0
],
dtype
=
z
.
dtype
,
device
=
z
.
device
)
time
=
t
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
residual
=
self
.
unet
(
xt
,
t
,
mu_y
,
y_mask
,
speaker_id
)
residual
=
self
.
unet
(
xt
,
t
,
mu_y
,
y_mask
,
speaker_id
)
xt
=
self
.
noise_scheduler
.
step
(
xt
,
residual
,
mu_y
,
h
,
time
)
scheduler_residual
=
residual
-
mu_y
+
xt
xt
=
self
.
noise_scheduler
.
step
(
scheduler_residual
,
xt
,
t_new
,
num_inference_steps
)
xt
=
xt
*
y_mask
xt
=
xt
*
y_mask
return
xt
[:,
:,
:
y_max_length
]
return
xt
[:,
:,
:
y_max_length
]
src/diffusers/schedulers/scheduling_grad_tts.py
View file @
f941fc99
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
numpy
as
np
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
...
@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin
...
@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin
class
GradTTSScheduler
(
SchedulerMixin
,
ConfigMixin
):
class
GradTTSScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
def
__init__
(
self
,
self
,
timesteps
=
1000
,
beta_start
=
0.05
,
beta_start
=
0.0001
,
beta_end
=
20
,
beta_end
=
0.02
,
tensor_format
=
"np"
,
tensor_format
=
"np"
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register_to_config
(
self
.
register_to_config
(
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_end
=
beta_end
,
)
)
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
betas
=
None
def
get_timesteps
(
self
,
num_inference_steps
):
return
np
.
array
([(
t
+
0.5
)
/
num_inference_steps
for
t
in
range
(
num_inference_steps
)])
def
set_betas
(
self
,
num_inference_steps
):
timesteps
=
self
.
get_timesteps
(
num_inference_steps
)
self
.
betas
=
np
.
array
([
self
.
beta_start
+
(
self
.
beta_end
-
self
.
beta_start
)
*
t
for
t
in
timesteps
])
def
step
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
# This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
if
self
.
betas
is
None
:
self
.
set_betas
(
num_inference_steps
)
def
sample_noise
(
self
,
timestep
):
beta_t
=
self
.
betas
[
t
]
noise
=
self
.
beta_start
+
(
self
.
beta_end
-
self
.
beta_start
)
*
timestep
beta_t_deriv
=
beta_t
/
num_inference_steps
return
noise
def
step
(
self
,
xt
,
residual
,
mu
,
h
,
timestep
):
sample_deriv
=
residual
*
beta_t_deriv
/
2
noise_t
=
self
.
sample_noise
(
timestep
)
dxt
=
0.5
*
(
mu
-
xt
-
residual
)
dxt
=
dxt
*
noise_t
*
h
xt
=
xt
-
dxt
return
xt
def
__len__
(
self
):
sample
=
sample
+
sample_deriv
return
len
(
self
.
config
.
timesteps
)
return
sample
tests/test_modeling_utils.py
View file @
f941fc99
...
@@ -31,6 +31,7 @@ from diffusers import (
...
@@ -31,6 +31,7 @@ from diffusers import (
GlideSuperResUNetModel
,
GlideSuperResUNetModel
,
GlideTextToImageUNetModel
,
GlideTextToImageUNetModel
,
GradTTSPipeline
,
GradTTSPipeline
,
GradTTSScheduler
,
LatentDiffusionPipeline
,
LatentDiffusionPipeline
,
PNDMPipeline
,
PNDMPipeline
,
PNDMScheduler
,
PNDMScheduler
,
...
@@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_grad_tts
(
self
):
def
test_grad_tts
(
self
):
model_id
=
"fusing/grad-tts-libri-tts"
model_id
=
"fusing/grad-tts-libri-tts"
grad_tts
=
GradTTSPipeline
.
from_pretrained
(
model_id
)
grad_tts
=
GradTTSPipeline
.
from_pretrained
(
model_id
)
noise_scheduler
=
GradTTSScheduler
()
grad_tts
.
noise_scheduler
=
noise_scheduler
text
=
"Hello world, I missed you so much."
text
=
"Hello world, I missed you so much."
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment