Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f941fc99
Commit
f941fc99
authored
Jun 22, 2022
by
Patrick von Platen
Browse files
refactor tts sampler a bit
parent
4fbf8c81
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
17 deletions
+30
-17
src/diffusers/pipelines/pipeline_glide.py
src/diffusers/pipelines/pipeline_glide.py
+1
-0
src/diffusers/pipelines/pipeline_grad_tts.py
src/diffusers/pipelines/pipeline_grad_tts.py
+4
-2
src/diffusers/schedulers/scheduling_grad_tts.py
src/diffusers/schedulers/scheduling_grad_tts.py
+22
-15
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+3
-0
No files found.
src/diffusers/pipelines/pipeline_glide.py
View file @
f941fc99
...
...
@@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
# END OF THE CLIP MODEL COPY-PASTE
#####################
def
_extract_into_tensor
(
arr
,
timesteps
,
broadcast_shape
):
"""
Extract values from a 1-D numpy array for a batch of indices.
...
...
src/diffusers/pipelines/pipeline_grad_tts.py
View file @
f941fc99
...
...
@@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline):
xt
=
z
*
y_mask
h
=
1.0
/
num_inference_steps
# (Patrick: TODO)
for
t
in
tqdm
.
tqdm
(
range
(
num_inference_steps
),
total
=
num_inference_steps
):
t_new
=
num_inference_steps
-
t
-
1
t
=
(
1.0
-
(
t
+
0.5
)
*
h
)
*
torch
.
ones
(
z
.
shape
[
0
],
dtype
=
z
.
dtype
,
device
=
z
.
device
)
time
=
t
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
residual
=
self
.
unet
(
xt
,
t
,
mu_y
,
y_mask
,
speaker_id
)
xt
=
self
.
noise_scheduler
.
step
(
xt
,
residual
,
mu_y
,
h
,
time
)
scheduler_residual
=
residual
-
mu_y
+
xt
xt
=
self
.
noise_scheduler
.
step
(
scheduler_residual
,
xt
,
t_new
,
num_inference_steps
)
xt
=
xt
*
y_mask
return
xt
[:,
:,
:
y_max_length
]
src/diffusers/schedulers/scheduling_grad_tts.py
View file @
f941fc99
...
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin
class
GradTTSScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
self
,
timesteps
=
1000
,
beta_start
=
0.0001
,
beta_end
=
0.02
,
beta_start
=
0.05
,
beta_end
=
20
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register_to_config
(
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
)
self
.
set_format
(
tensor_format
=
tensor_format
)
self
.
betas
=
None
def
get_timesteps
(
self
,
num_inference_steps
):
return
np
.
array
([(
t
+
0.5
)
/
num_inference_steps
for
t
in
range
(
num_inference_steps
)])
def
set_betas
(
self
,
num_inference_steps
):
timesteps
=
self
.
get_timesteps
(
num_inference_steps
)
self
.
betas
=
np
.
array
([
self
.
beta_start
+
(
self
.
beta_end
-
self
.
beta_start
)
*
t
for
t
in
timesteps
])
def
step
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
# This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
if
self
.
betas
is
None
:
self
.
set_betas
(
num_inference_steps
)
def
sample_noise
(
self
,
timestep
):
noise
=
self
.
beta_start
+
(
self
.
beta_end
-
self
.
beta_start
)
*
timestep
return
noise
beta_t
=
self
.
betas
[
t
]
beta_t_deriv
=
beta_t
/
num_inference_steps
def
step
(
self
,
xt
,
residual
,
mu
,
h
,
timestep
):
noise_t
=
self
.
sample_noise
(
timestep
)
dxt
=
0.5
*
(
mu
-
xt
-
residual
)
dxt
=
dxt
*
noise_t
*
h
xt
=
xt
-
dxt
return
xt
sample_deriv
=
residual
*
beta_t_deriv
/
2
def
__len__
(
self
):
return
len
(
self
.
config
.
timesteps
)
sample
=
sample
+
sample_deriv
return
sample
tests/test_modeling_utils.py
View file @
f941fc99
...
...
@@ -31,6 +31,7 @@ from diffusers import (
GlideSuperResUNetModel
,
GlideTextToImageUNetModel
,
GradTTSPipeline
,
GradTTSScheduler
,
LatentDiffusionPipeline
,
PNDMPipeline
,
PNDMScheduler
,
...
...
@@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase):
def
test_grad_tts
(
self
):
model_id
=
"fusing/grad-tts-libri-tts"
grad_tts
=
GradTTSPipeline
.
from_pretrained
(
model_id
)
noise_scheduler
=
GradTTSScheduler
()
grad_tts
.
noise_scheduler
=
noise_scheduler
text
=
"Hello world, I missed you so much."
generator
=
torch
.
manual_seed
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment