Commit f941fc99 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

refactor tts sampler a bit

parent 4fbf8c81
...@@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel): ...@@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
# END OF THE CLIP MODEL COPY-PASTE # END OF THE CLIP MODEL COPY-PASTE
##################### #####################
def _extract_into_tensor(arr, timesteps, broadcast_shape): def _extract_into_tensor(arr, timesteps, broadcast_shape):
""" """
Extract values from a 1-D numpy array for a batch of indices. Extract values from a 1-D numpy array for a batch of indices.
......
...@@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline): ...@@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline):
xt = z * y_mask xt = z * y_mask
h = 1.0 / num_inference_steps h = 1.0 / num_inference_steps
# (Patrick: TODO)
for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps): for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
t_new = num_inference_steps - t - 1
t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
time = t.unsqueeze(-1).unsqueeze(-1)
residual = self.unet(xt, t, mu_y, y_mask, speaker_id) residual = self.unet(xt, t, mu_y, y_mask, speaker_id)
xt = self.noise_scheduler.step(xt, residual, mu_y, h, time) scheduler_residual = residual - mu_y + xt
xt = self.noise_scheduler.step(scheduler_residual, xt, t_new, num_inference_steps)
xt = xt * y_mask xt = xt * y_mask
return xt[:, :, :y_max_length] return xt[:, :, :y_max_length]
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
...@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin ...@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin
class GradTTSScheduler(SchedulerMixin, ConfigMixin): class GradTTSScheduler(SchedulerMixin, ConfigMixin):
def __init__( def __init__(
self, self,
timesteps=1000, beta_start=0.05,
beta_start=0.0001, beta_end=20,
beta_end=0.02,
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register_to_config( self.register_to_config(
timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
) )
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
self.betas = None
def get_timesteps(self, num_inference_steps):
return np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)])
def set_betas(self, num_inference_steps):
timesteps = self.get_timesteps(num_inference_steps)
self.betas = np.array([self.beta_start + (self.beta_end - self.beta_start) * t for t in timesteps])
def step(self, residual, sample, t, num_inference_steps):
# This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
if self.betas is None:
self.set_betas(num_inference_steps)
def sample_noise(self, timestep): beta_t = self.betas[t]
noise = self.beta_start + (self.beta_end - self.beta_start) * timestep beta_t_deriv = beta_t / num_inference_steps
return noise
def step(self, xt, residual, mu, h, timestep): sample_deriv = residual * beta_t_deriv / 2
noise_t = self.sample_noise(timestep)
dxt = 0.5 * (mu - xt - residual)
dxt = dxt * noise_t * h
xt = xt - dxt
return xt
def __len__(self): sample = sample + sample_deriv
return len(self.config.timesteps) return sample
...@@ -31,6 +31,7 @@ from diffusers import ( ...@@ -31,6 +31,7 @@ from diffusers import (
GlideSuperResUNetModel, GlideSuperResUNetModel,
GlideTextToImageUNetModel, GlideTextToImageUNetModel,
GradTTSPipeline, GradTTSPipeline,
GradTTSScheduler,
LatentDiffusionPipeline, LatentDiffusionPipeline,
PNDMPipeline, PNDMPipeline,
PNDMScheduler, PNDMScheduler,
...@@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase):
def test_grad_tts(self): def test_grad_tts(self):
model_id = "fusing/grad-tts-libri-tts" model_id = "fusing/grad-tts-libri-tts"
grad_tts = GradTTSPipeline.from_pretrained(model_id) grad_tts = GradTTSPipeline.from_pretrained(model_id)
noise_scheduler = GradTTSScheduler()
grad_tts.noise_scheduler = noise_scheduler
text = "Hello world, I missed you so much." text = "Hello world, I missed you so much."
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment