Commit 7b55d334 authored by patil-suraj's avatar patil-suraj

begin pipeline

parent 986cc9b2
@@ -7,6 +7,9 @@ from torch import nn
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
from diffusers import DiffusionPipeline
from .grad_tts_utils import text_to_sequence
def sequence_mask(length, max_length=None):
@@ -383,3 +386,18 @@ class TextEncoder(ModelMixin, ConfigMixin):
        logw = self.proj_w(x_dp, x_mask)
        return mu, logw, x_mask
class GradTTS(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        # register_modules stores the components as attributes of the pipeline.
        self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None):
        if torch_device is None:
            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Sampling loop not implemented yet in this commit.
        pass
\ No newline at end of file
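For orientation only: the `__call__` above is still a stub, so the sketch below shows the generic reverse-diffusion sampling loop a `DiffusionPipeline` typically fills it with, starting from noise and repeatedly denoising with the model and the scheduler. The `sampling_sketch` name, the `mu` prior argument, and the `noise_scheduler.step(...)` call are assumptions for illustration, not part of this diff.

# Hypothetical sketch, not part of this commit: illustrates the usual sampling
# pattern (draw noise, iteratively denoise with the model and the scheduler).
# The scheduler `step` signature and the shape of `mu` are assumptions.
import torch


@torch.no_grad()
def sampling_sketch(unet, noise_scheduler, mu, num_inference_steps, generator=None):
    # Start from Gaussian noise shaped like the encoder prior `mu`.
    x = torch.randn(mu.shape, generator=generator)
    # Walk the diffusion process backwards one scheduler step at a time.
    for t in reversed(range(num_inference_steps)):
        noise_pred = unet(x, t)                     # model predicts the noise at step t
        x = noise_scheduler.step(noise_pred, t, x)  # hypothetical scheduler step API
    return x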