Commit 71ecc7ae authored by patil-suraj's avatar patil-suraj
Browse files

add speaker emb in unet

parent 3f2d46a1
......@@ -154,6 +154,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.pe_scale = pe_scale
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
torch.nn.Linear(spk_emb_dim * 4, n_feats))
self.time_pos_emb = SinusoidalPosEmb(dim)
......@@ -189,6 +190,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.final_conv = torch.nn.Conv2d(dim, 1, 1)
def forward(self, x, mask, mu, t, spk=None):
if self.n_spks > 1:
# Get speaker embedding
spk = self.spk_emb(spk)
if not isinstance(spk, type(None)):
s = self.spk_mlp(spk)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment