"docs/vscode:/vscode.git/clone" did not exist on "f19a9204cdbd6b8360f85f404463aa45af9ee00b"
Commit e6e33f1a authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2698 canceled with stages
# 模型编码
modelCode=1516
# 模型名称
modelName=MegaTTS3_pytorch
# 模型描述
modelDescription=骨干网络仅含0.45B参数,支持口音强度控制,适于实时语音交互,能满足不同场景下对语音口音克隆的多样化需求。
# 应用场景
appScenario=推理,语音合成,广媒,影视,动漫,医疗,家居,教育
# 框架类型
frameType=pytorch
<div align="center">
<h1>
MegaTTS 3 <img src="./assets/fig/Hi.gif" width="40px">
</h1>
<p>
Official PyTorch Implementation<br>
</p>
</div>
<div align="center">
<a href="https://huggingface.co/spaces/ByteDance/MegaTTS3"><img src="https://img.shields.io/badge/Hugging%20Face-Space%20Demo-yellow" alt="Hugging Face"></a>
<a href="#"><img src="https://img.shields.io/badge/Platform-linux-lightgrey" alt="version"></a>
<a href="#"><img src="https://img.shields.io/badge/Python-3.10-brightgreen" alt="version"></a>
<a href="#"><img src="https://img.shields.io/badge/PyTorch-2.3.0-orange" alt="python"></a>
<a href="#"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="mit"></a>
</div>
<div align="center">
<img src="https://img.shields.io/badge/Bytedance-%230077B5.svg?&style=flat-square&logo=bytedance&logoColor=white" />
<img src="https://img.shields.io/badge/Zhejiang University-%230077B5.svg?&style=flat-square&logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCA1MTIgNTEyIj48IS0tIUZvbnQgQXdlc29tZSBGcmVlIDYuNy4yIGJ5IEBmb250YXdlc29tZSAtIGh0dHBzOi8vZm9udGF3ZXNvbWUuY29tIExpY2Vuc2UgLSBodHRwczovL2ZvbnRhd2Vzb21lLmNvbS9saWNlbnNlL2ZyZWUgQ29weXJpZ2h0IDIwMjUgRm9udGljb25zLCBJbmMuLS0+PHBhdGggZmlsbD0iI2ZmZmZmZiIgZD0iTTI0My40IDIuNmwtMjI0IDk2Yy0xNCA2LTIxLjggMjEtMTguNyAzNS44UzE2LjggMTYwIDMyIDE2MGwwIDhjMCAxMy4zIDEwLjcgMjQgMjQgMjRsNDAwIDBjMTMuMyAwIDI0LTEwLjcgMjQtMjRsMC04YzE1LjIgMCAyOC4zLTEwLjcgMzEuMy0yNS42cy00LjgtMjkuOS0xOC43LTM1LjhsLTIyNC05NmMtOC0zLjQtMTcuMi0zLjQtMjUuMiAwek0xMjggMjI0bC02NCAwIDAgMTk2LjNjLS42IC4zLTEuMiAuNy0xLjggMS4xbC00OCAzMmMtMTEuNyA3LjgtMTcgMjIuNC0xMi45IDM1LjlTMTcuOSA1MTIgMzIgNTEybDQ0OCAwYzE0LjEgMCAyNi41LTkuMiAzMC42LTIyLjdzLTEuMS0yOC4xLTEyLjktMzUuOWwtNDgtMzJjLS42LS40LTEuMi0uNy0xLjgtMS4xTDQ0OCAyMjRsLTY0IDAgMCAxOTItNDAgMCAwLTE5Mi02NCAwIDAgMTkyLTQ4IDAgMC0xOTItNjQgMCAwIDE5Mi00MCAwIDAtMTkyek0yNTYgNjRhMzIgMzIgMCAxIDEgMCA2NCAzMiAzMiAwIDEgMSAwLTY0eiIvPjwvc3ZnPg==&logoColor=white" />
</div>
## Key features
- 🚀**Lightweight and Efficient:** The backbone of the TTS Diffusion Transformer has only 0.45B parameters.
- 🎧**Ultra High-Quality Voice Cloning:** You can try our model at [Huggingface Demo](https://huggingface.co/spaces/ByteDance/MegaTTS3)🎉. The .wav and .npy files can be found at [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing). Submit a sample (.wav format, < 24s, and please do not contain space in filename) on [link2](https://drive.google.com/drive/folders/1gCWL1y_2xu9nIFhUX_OW5MbcFuB7J5Cl?usp=sharing) to receive .npy voice latents you can use locally.
- 🌍**Bilingual Support:** Supports both Chinese and English, and code-switching.
- ✍️**Controllable:** Supports accent intensity control ✅ and fine-grained pronunciation/duration adjustment (coming soon).
[MegaTTS 3 Demo Video](https://github.com/user-attachments/assets/0174c111-f392-4376-a34b-0b5b8164aacc)
<div style='width:100%;text-align:center'>
<img src="./assets/fig/table_tts.png" width="550px">
</div>
## 🎯Roadmap
- **[2025-03-22]** Our project has been released!
## Installation
``` sh
# Clone the repository
git clone https://github.com/bytedance/MegaTTS3
cd MegaTTS3
```
**Requirements (for Linux)**
``` sh
# Create a python 3.10 conda env (you could also use virtualenv)
conda create -n megatts3-env python=3.10
conda activate megatts3-env
pip install -r requirements.txt
# Set the root directory
export PYTHONPATH="/path/to/MegaTTS3:$PYTHONPATH"
# [Optional] Set GPU
export CUDA_VISIBLE_DEVICES=0
# If you encounter bugs with pydantic in inference, you should check if the versions of pydantic and gradio are matched.
# [Note] if you encounter bugs related with httpx, please check that whether your environmental variable "no_proxy" has patterns like "::"
```
**Requirements (for Windows)**
``` sh
# [The Windows version is currently under testing]
# Comment below dependence in requirements.txt:
# # WeTextProcessing==1.0.4.1
# Create a python 3.10 conda env (you could also use virtualenv)
conda create -n megatts3-env python=3.10
conda activate megatts3-env
pip install -r requirements.txt
conda install -y -c conda-forge pynini==2.1.5
pip install WeTextProcessing==1.0.3
# [Optional] If you want GPU inference, you may need to install specific version of PyTorch for your GPU from https://pytorch.org/.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
# [Note] if you encounter bugs related with `ffprobe` or `ffmpeg`, you can install it through `conda install -c conda-forge ffmpeg`
# Set environment variable for root directory
set PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # Windows
$env:PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # Powershell on Windows
conda env config vars set PYTHONPATH="C:\path\to\MegaTTS3;%PYTHONPATH%" # For conda users
# [Optional] Set GPU
set CUDA_VISIBLE_DEVICES=0 # Windows
$env:CUDA_VISIBLE_DEVICES=0 # Powershell on Windows
```
**Requirements (for Docker)**
``` sh
# [The Docker version is currently under testing]
# ! You should download the pretrained checkpoint before running the following command
docker build . -t megatts3:latest
# For GPU inference
docker run -it -p 7929:7929 --gpus all -e CUDA_VISIBLE_DEVICES=0 megatts3:latest
# For CPU inference
docker run -it -p 7929:7929 megatts3:latest
# Visit http://0.0.0.0:7929/ for gradio.
```
**Model Download**
The pretrained checkpoint can be found at [Google Drive](https://drive.google.com/drive/folders/1CidiSqtHgJTBDAHQ746_on_YR0boHDYB?usp=sharing) or [Huggingface](https://huggingface.co/ByteDance/MegaTTS3). Please download them and put them to ``./checkpoints/xxx``.
> [!IMPORTANT]
> For security issues, we do not upload the parameters of WaveVAE encoder to the above links. You can only use the pre-extracted latents from [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing) for inference. If you want to synthesize speech for speaker A, you need "A.wav" and "A.npy" in the same directory. If you have any questions or suggestions for our model, please email us.
>
> This project is primarily intended for academic purposes. For academic datasets requiring evaluation, you may upload them to the voice request queue in [link2](https://drive.google.com/drive/folders/1gCWL1y_2xu9nIFhUX_OW5MbcFuB7J5Cl?usp=sharing) (within 24s for each clip). After verifying that your uploaded voices are free from safety issues, we will upload their latent files to [link1](https://drive.google.com/drive/folders/1QhcHWcy20JfqWjgqZX1YM3I6i9u4oNlr?usp=sharing) as soon as possible.
>
> In the coming days, we will also prepare and release the latent representations for some common TTS benchmarks.
## Inference
**Command-Line Usage (Standard)**
``` bash
# p_w (intelligibility weight), t_w (similarity weight). Typically, prompt with more noises requires higher p_w and t_w
python tts/infer_cli.py --input_wav 'assets/Chinese_prompt.wav' --input_text "另一边的桌上,一位读书人嗤之以鼻道,'佛子三藏,神子燕小鱼是什么样的人物,李家的那个李子夜如何与他们相提并论?'" --output_dir ./gen
# As long as audio volume and pronunciation are appropriate, increasing --t_w within reasonable ranges (2.0~5.0)
# will increase the generated speech's expressiveness and similarity (especially for some emotional cases).
python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text 'As his long promised tariff threat turned into reality this week, top human advisers began fielding a wave of calls from business leaders, particularly in the automotive sector, along with lawmakers who were sounding the alarm.' --output_dir ./gen --p_w 2.0 --t_w 3.0
```
**Command-Line Usage (for TTS with Accents)**
``` bash
# When p_w (intelligibility weight) ≈ 1.0, the generated audio closely retains the speaker’s original accent. As p_w increases, it shifts toward standard pronunciation.
# t_w (similarity weight) is typically set 0–3 points higher than p_w for optimal results.
# Useful for accented TTS or solving the accent problems in cross-lingual TTS.
python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text '这是一条有口音的音频。' --output_dir ./gen --p_w 1.0 --t_w 3.0
python tts/infer_cli.py --input_wav 'assets/English_prompt.wav' --input_text '这条音频的发音标准一些了吗?' --output_dir ./gen --p_w 2.5 --t_w 2.5
```
**Web UI Usage**
``` bash
# We also support cpu inference, but it may take about 30 seconds (for 10 inference steps).
python tts/gradio_api.py
```
## Submodules
> [!TIP]
> In addition to TTS, some submodules in this project may also have additional usages.
> See ``./tts/frontend_fuction.py`` and ``./tts/infer_cli.py`` for example code.
### Aligner
**Description:** a robust speech-text aligner model trained using pseudo-labels generated by a large number of MFA expert models.
**Usage**: 1) Prepare the finetuning dataset for our model; 2) Filter the large-scale speech dataset (if the aligner fails to align a certain speech clip, it is likely to be noisy); 3) Phoneme recognition; 4) Speech segmentation.
### Graphme-to-Phoneme Model
**Description:** a Qwen2.5-0.5B model finetuned for robust graphme-to-phoneme conversion.
**Usage**: Graphme-to-phoneme conversion.
### WaveVAE
**Description:** a strong waveform VAE that can compress 24 kHz speeche into 25 Hz acoustic latent and reconstruct the original wave almost losslessly.
**Usage:** 1) Acoustic latents can provide a more compact and discriminative training target for speech synthesis models compared to mel-spectrograms, accelerating convergence; 2) Used as acoustic latents for voice conversion; 3) High-quality vocoder.
<div style='width:100%;text-align:center'>
<img src="./assets/fig/table_wavvae.png" width="650px">
</div>
## Security
If you discover a potential security issue in this project, or think you may
have discovered a security issue, we ask that you notify Bytedance Security via our [security center](https://security.bytedance.com/src) or [sec@bytedance.com](sec@bytedance.com).
Please do **not** create a public GitHub issue.
## License
This project is licensed under the [Apache-2.0 License](LICENSE).
## Citation
This repo contains forced-align version of `Sparse Alignment Enhanced Latent Diffusion Transformer for Zero-Shot Speech Synthesis` and the WavVAE is mainly based on `Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling`. Compared to the model described in paper, the repository includes additional models. These models not only enhance the stability and cloning capabilities of the algorithm but can also be independently utilized to serve a wider range of scenarios.
```
@article{jiang2025sparse,
title={Sparse Alignment Enhanced Latent Diffusion Transformer for Zero-Shot Speech Synthesis},
author={Jiang, Ziyue and Ren, Yi and Li, Ruiqi and Ji, Shengpeng and Ye, Zhenhui and Zhang, Chen and Jionghao, Bai and Yang, Xiaoda and Zuo, Jialong and Zhang, Yu and others},
journal={arXiv preprint arXiv:2502.18924},
year={2025}
}
@article{ji2024wavtokenizer,
title={Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling},
author={Ji, Shengpeng and Jiang, Ziyue and Wang, Wen and Chen, Yifu and Fang, Minghui and Zuo, Jialong and Yang, Qian and Cheng, Xize and Wang, Zehan and Li, Ruiqi and others},
journal={arXiv preprint arXiv:2408.16532},
year={2024}
}
```
# torch==2.3.0
# torchaudio==2.3.0
numpy<2
setproctitle==1.3.3
attrdict==2.0.1
librosa==0.10.2.post1
langdetect==1.0.9
pydub==0.25.1
pyloudnorm==0.1.1
modelscope==1.22.2
WeTextProcessing==1.0.4.1
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
x-transformers==1.44.4
torchdiffeq==0.2.5
openai-whisper==20240930
httpx==0.28.1
gradio==5.23.1
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import whisper
import librosa
from copy import deepcopy
from tts.utils.text_utils.ph_tone_convert import split_ph_timestamp, split_ph
from tts.utils.audio_utils.align import mel2token_to_dur
''' Graphme to phoneme function '''
def g2p(self, text_inp):
# prepare inputs
txt_token = self.g2p_tokenizer('<BOT>' + text_inp + '<BOS>')['input_ids']
input_ids = torch.LongTensor([txt_token+[145+self.speech_start_idx]]).to(self.device)
# model forward
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
outputs = self.g2p_model.generate(input_ids, max_new_tokens=256, do_sample=True, top_k=1, eos_token_id=800+1+self.speech_start_idx)
# process outputs
ph_tokens = outputs[:, len(txt_token):-1]-self.speech_start_idx
ph_pred, tone_pred = split_ph(ph_tokens[0])
ph_pred, tone_pred = ph_pred[None, :].to(self.device), tone_pred[None, :].to(self.device)
return ph_pred, tone_pred
''' Get phoneme2mel align of prompt speech '''
def align(self, wav):
with torch.inference_mode():
whisper_wav = librosa.resample(wav, orig_sr=self.sr, target_sr=16000)
mel = torch.FloatTensor(whisper.log_mel_spectrogram(whisper_wav).T).to(self.device)[None].transpose(1,2)
prompt_max_frame = mel.size(2) // self.fm * self.fm
mel = mel[:, :, :prompt_max_frame]
token = torch.LongTensor([[798]]).to(self.device)
audio_features = self.aligner_lm.embed_audio(mel)
for i in range(768):
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
logits = self.aligner_lm.logits(token, audio_features, None)
token_pred = torch.argmax(F.softmax(logits[:, -1], dim=-1), 1)[None]
token = torch.cat([token, token_pred], dim=1)
if token_pred[0] == 799:
break
alignment_tokens = token
ph_ref, tone_ref, dur_ref, _ = split_ph_timestamp(deepcopy(alignment_tokens)[0, 1:-1])
ph_ref = torch.Tensor(ph_ref)[None].to(self.device)
tone_ref = torch.Tensor(tone_ref)[None].to(self.device)
if dur_ref.sum() < prompt_max_frame:
dur_ref[-1] += prompt_max_frame - dur_ref.sum()
elif dur_ref.sum() > prompt_max_frame:
len_diff = dur_ref.sum() - prompt_max_frame
while True:
for i in range(len(dur_ref)):
dur_ref[i] -= 1
len_diff -= 1
if len_diff == 0:
break
if len_diff == 0:
break
mel2ph_ref = self.length_regulator(dur_ref[None]).to(self.device)
mel2ph_ref = mel2ph_ref[:, :mel2ph_ref.size(1)//self.fm*self.fm]
return ph_ref, tone_ref, mel2ph_ref
''' Duration Prompting '''
def make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref):
dur_tokens_2d_ = mel2token_to_dur(mel2ph_ref, ph_ref.shape[1]).clamp(
max=self.hp_dur_model['dur_code_size'] - 1) + 1
ctx_dur_tokens = dur_tokens_2d_.clone().flatten(0, 1).to(self.device)
txt_tokens_flat_ = ph_ref.flatten(0, 1)
ctx_dur_tokens = ctx_dur_tokens[txt_tokens_flat_ > 0][None]
last_dur_pos_prompt = ctx_dur_tokens.shape[1]
dur_spk_pos_ids_flat = range(0, last_dur_pos_prompt)
dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
_, incremental_state_dur_prompt = self.dur_model.infer(
ph_ref, {'tone': tone_ref}, None, None, None,
ctx_vqcodes=ctx_dur_tokens, spk_pos_ids_flat=dur_spk_pos_ids_flat, return_state=True)
return incremental_state_dur_prompt, ctx_dur_tokens
''' Duration Prediction '''
def dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first, is_final):
last_dur_token = ctx_dur_tokens[:, -1:]
last_dur_pos_prompt = ctx_dur_tokens.shape[1]
incremental_state_dur = deepcopy(incremental_state_dur_prompt)
txt_len = ph_pred.shape[1]
dur_spk_pos_ids_flat = range(last_dur_pos_prompt, last_dur_pos_prompt + txt_len)
dur_spk_pos_ids_flat = torch.LongTensor([dur_spk_pos_ids_flat]).to(self.device)
last_dur_pos_prompt = last_dur_pos_prompt + txt_len
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
dur_pred = self.dur_model.infer(
ph_pred, {'tone': tone_pred}, None, None, None,
incremental_state=incremental_state_dur,
first_decoder_inp=last_dur_token,
spk_pos_ids_flat=dur_spk_pos_ids_flat,
)
dur_pred = dur_pred - 1
dur_pred = dur_pred.clamp(0, self.hp_dur_model['dur_code_size'] - 1)
# if is_final:
# dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
# else:
# dur_pred[:, -1] = dur_pred[:, -1].clamp(48, 128)
# if seg_i > 0:
# dur_pred[:, 0] = 0
# ['。', '!', '?', 'sil']
# for sil_token in [148, 153, 166, 145]:
# dur_pred[ph_pred==sil_token].clamp_min(32)
# # [',', ';']
# for sil_token in [163, 165]:
# dur_pred[ph_pred==sil_token].clamp_min(16)
if not is_final:
# add 0.32ms for crossfade
dur_pred[:, -1] = dur_pred[:, -1] + 32
else:
dur_pred[:, -1] = dur_pred[:, -1].clamp(64, 128)
''' DiT target speech generation '''
dur_disturb_choice = (torch.rand_like(dur_pred.float()) > 0.5).float()
dur_disturb_r = 1 + torch.rand_like(dur_pred.float()) * dur_disturb
dur_pred = dur_pred * dur_disturb_r * dur_disturb_choice + \
dur_pred / dur_disturb_r * (1 - dur_disturb_choice)
dur_pred = torch.round(dur_pred * dur_alpha).clamp(0, 127)
# ['。', '!', '?', 'sil']
for sil_token in [148, 153, 166, 145]:
dur_pred[ph_pred==sil_token] = dur_pred[ph_pred==sil_token].clamp_min(64)
# [',', ';']
for sil_token in [163, 165]:
dur_pred[ph_pred==sil_token] = dur_pred[ph_pred==sil_token].clamp_min(32)
if is_first:
dur_pred[:, 0] = 8
dur_sum = dur_pred.sum()
npad = self.fm - dur_sum % self.fm
if npad < self.fm:
dur_pred[:, -1] += npad
mel2ph_pred = self.length_regulator(dur_pred).to(self.device)
return mel2ph_pred
def prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent):
# Prepare duration token
mel2ph_pred = torch.cat((mel2ph_ref, mel2ph_pred+ph_ref.size(1)), dim=1)
mel2ph_pred = mel2ph_pred[:, :mel2ph_pred.size(1)//self.fm*self.fm].repeat(3, 1)
# Prepare phone and tone token
ph_pred = torch.cat((ph_ref, ph_pred), dim=1)
tone_pred = torch.cat((tone_ref, tone_pred), dim=1)
# Disable the English tone (set them to 3)"""
en_tone_idx = ~((tone_pred == 4) | ( (11 <= tone_pred) & (tone_pred <= 15)) | (tone_pred == 0))
tone_pred[en_tone_idx] = 3
# Prepare cfg inputs
ph_seq = torch.cat([ph_pred, ph_pred, torch.full(ph_pred.size(), self.cfg_mask_token_phone, device=self.device)], 0)
tone_seq = torch.cat([tone_pred, tone_pred, torch.full(tone_pred.size(), self.cfg_mask_token_tone, device=self.device)], 0)
target_size = mel2ph_pred.size(1)//self.vae_stride
vae_latent_ = vae_latent.repeat(3, 1, 1)
ctx_mask = torch.ones_like(vae_latent_[:, :, 0:1])
vae_latent_ = F.pad(vae_latent_, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
vae_latent_[1:] = 0.0
ctx_mask = F.pad(ctx_mask, (0, 0, 0, target_size - vae_latent.size(1)), mode='constant', value=0)
return {
'phone': ph_seq,
'tone': tone_seq,
"lat_ctx": vae_latent_ * ctx_mask,
"ctx_mask": ctx_mask,
"dur": mel2ph_pred,
}
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing as mp
import torch
import os
from functools import partial
import gradio as gr
import traceback
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
def model_worker(input_queue, output_queue, device_id):
device = None
if device_id is not None:
device = torch.device(f'cuda:{device_id}')
infer_pipe = MegaTTS3DiTInfer(device=device)
while True:
task = input_queue.get()
inp_audio_path, inp_npy_path, inp_text, infer_timestep, p_w, t_w = task
try:
convert_to_wav(inp_audio_path)
wav_path = os.path.splitext(inp_audio_path)[0] + '.wav'
cut_wav(wav_path, max_len=28)
with open(wav_path, 'rb') as file:
file_content = file.read()
resource_context = infer_pipe.preprocess(file_content, latent_file=inp_npy_path)
wav_bytes = infer_pipe.forward(resource_context, inp_text, time_step=infer_timestep, p_w=p_w, t_w=t_w)
output_queue.put(wav_bytes)
except Exception as e:
traceback.print_exc()
print(task, str(e))
output_queue.put(None)
def main(inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w, processes, input_queue, output_queue):
print("Push task to the inp queue |", inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w)
input_queue.put((inp_audio, inp_npy, inp_text, infer_timestep, p_w, t_w))
res = output_queue.get()
if res is not None:
return res
else:
print("")
return None
if __name__ == '__main__':
mp.set_start_method('spawn', force=True)
mp_manager = mp.Manager()
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if devices != '':
devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
else:
devices = None
num_workers = 1
input_queue = mp_manager.Queue()
output_queue = mp_manager.Queue()
processes = []
print("Start open workers")
for i in range(num_workers):
p = mp.Process(target=model_worker, args=(input_queue, output_queue, i % len(devices) if devices is not None else None))
p.start()
processes.append(p)
api_interface = gr.Interface(fn=
partial(main, processes=processes, input_queue=input_queue,
output_queue=output_queue),
inputs=[gr.Audio(type="filepath", label="Upload .wav"), gr.File(type="filepath", label="Upload .npy"), "text",
gr.Number(label="infer timestep", value=32),
gr.Number(label="Intelligibility Weight", value=1.4),
gr.Number(label="Similarity Weight", value=3.0)], outputs=[gr.Audio(label="Synthesized Audio")],
title="MegaTTS3",
description="Upload a speech clip as a reference for timbre, " +
"upload the pre-extracted latent file, "+
"input the target text, and receive the cloned voice.", concurrency_limit=1)
api_interface.launch(server_name='0.0.0.0', server_port=7929, debug=True)
for p in processes:
p.join()
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import argparse
import librosa
import numpy as np
import torch
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
from langdetect import detect as classify_language
from pydub import AudioSegment
import pyloudnorm as pyln
from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
from tts.utils.audio_utils.io import save_wav, to_wav_bytes, convert_to_wav_bytes, combine_audio_segments
from tts.utils.commons.ckpt_utils import load_ckpt
from tts.utils.commons.hparams import set_hparams, hparams
from tts.utils.text_utils.text_encoder import TokenTextEncoder
from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english, chunk_text_chinesev2
from tts.utils.commons.hparams import hparams, set_hparams
if "TOKENIZERS_PARALLELISM" not in os.environ:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def convert_to_wav(wav_path):
# Check if the file exists
if not os.path.exists(wav_path):
print(f"The file '{wav_path}' does not exist.")
return
# Check if the file already has a .wav extension
if not wav_path.endswith(".wav"):
# Define the output path with a .wav extension
out_path = os.path.splitext(wav_path)[0] + ".wav"
# Load the audio file using pydub and convert it to WAV
audio = AudioSegment.from_file(wav_path)
audio.export(out_path, format="wav")
print(f"Converted '{wav_path}' to '{out_path}'")
def cut_wav(wav_path, max_len=28):
audio = AudioSegment.from_file(wav_path)
audio = audio[:int(max_len * 1000)]
audio.export(wav_path, format="wav")
class MegaTTS3DiTInfer():
def __init__(
self,
device=None,
ckpt_root='./checkpoints',
dit_exp_name='diffusion_transformer',
frontend_exp_name='aligner_lm',
wavvae_exp_name='wavvae',
dur_ckpt_path='duration_lm',
g2p_exp_name='g2p',
precision=torch.float32,#torch.float16存在精度问题
**kwargs
):
self.sr = 24000
self.fm = 8
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.precision = precision
# build models
self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
self.build_model(self.device)
# init text normalizer
self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
self.en_normalizer = EnNormalizer(overwrite_cache=False)
# loudness meter
self.loudness_meter = pyln.Meter(self.sr)
def build_model(self, device):
set_hparams(exp_name=self.dit_exp_name, print_hparams=False)
''' Load Dict '''
current_dir = os.path.dirname(os.path.abspath(__file__))
ling_dict = json.load(open(f"{current_dir}/utils/text_utils/dict.json", encoding='utf-8-sig'))
self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='<UNK>') for k in ['phone', 'tone']}
self.token_encoder = token_encoder = self.ling_dict['phone']
ph_dict_size = len(token_encoder)
''' Load Duration LM '''
from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
hp_dur_model['frames_multiple'] = hparams['frames_multiple']
self.dur_model = ARDurPredictor(
hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
hp_dur_model['dur_model_layers'], ph_dict_size,
hp_dur_model['dur_code_size'],
use_rot_embed=hp_dur_model.get('use_rot_embed', False))
self.length_regulator = LengthRegulator()
load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
self.dur_model.eval()
self.dur_model.to(device)
''' Load Diffusion Transformer '''
from tts.modules.llm_dit.dit import Diffusion
self.dit = Diffusion()
load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
self.dit.eval()
self.dit.to(device)
self.cfg_mask_token_phone = 302 - 1
self.cfg_mask_token_tone = 32 - 1
''' Load Frontend LM '''
from tts.modules.aligner.whisper_small import Whisper
self.aligner_lm = Whisper()
load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
self.aligner_lm.eval()
self.aligner_lm.to(device)
self.kv_cache = None
self.hooks = None
''' Load G2P LM'''
from transformers import AutoTokenizer, AutoModelForCausalLM
g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
g2p_tokenizer.padding_side = "right"
self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
self.g2p_tokenizer = g2p_tokenizer
self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]
''' Wav VAE '''
self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3
self.wavvae = WavVAE_V3(hparams=hp_wavvae)
if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
self.has_vae_encoder = True
else:
load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
self.has_vae_encoder = False
self.wavvae.eval()
self.wavvae.to(device)
self.vae_stride = hp_wavvae.get('vae_stride', 4)
self.hop_size = hp_wavvae.get('hop_size', 4)
def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
wav_bytes = convert_to_wav_bytes(audio_bytes)
''' Load wav '''
wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
# Pad wav if necessary
ws = hparams['win_size']
if len(wav) % ws < ws - 1:
wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
self.loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))
''' obtain alignments with aligner_lm '''
ph_ref, tone_ref, mel2ph_ref = align(self, wav)
with torch.inference_mode():
''' Forward WaveVAE to obtain: prompt latent '''
if self.has_vae_encoder:
wav = torch.FloatTensor(wav)[None].to(self.device)
vae_latent = self.wavvae.encode_latent(wav)
vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
else:
assert latent_file is not None, "Please provide latent_file in WaveVAE decoder-only mode"
vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
''' Duration Prompting '''
self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None
incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)
return {
'ph_ref': ph_ref,
'tone_ref': tone_ref,
'mel2ph_ref': mel2ph_ref,
'vae_latent': vae_latent,
'incremental_state_dur_prompt': incremental_state_dur_prompt,
'ctx_dur_tokens': ctx_dur_tokens,
}
def forward(self, resource_context, input_text, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):
device = self.device
ph_ref = resource_context['ph_ref'].to(device)
tone_ref = resource_context['tone_ref'].to(device)
mel2ph_ref = resource_context['mel2ph_ref'].to(device)
vae_latent = resource_context['vae_latent'].to(device)
ctx_dur_tokens = resource_context['ctx_dur_tokens'].to(device)
incremental_state_dur_prompt = resource_context['incremental_state_dur_prompt']
with torch.inference_mode():
''' Generating '''
wav_pred_ = []
language_type = classify_language(input_text)
if language_type == 'en':
input_text = self.en_normalizer.normalize(input_text)
text_segs = chunk_text_english(input_text, max_chars=130)
else:
input_text = self.zh_normalizer.normalize(input_text)
text_segs = chunk_text_chinesev2(input_text, limit=60)
for seg_i, text in enumerate(text_segs):
''' G2P '''
ph_pred, tone_pred = g2p(self, text)
''' Duration Prediction '''
mel2ph_pred = dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first=seg_i==0, is_final=seg_i==len(text_segs)-1)
inputs = prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent)
# Speech dit inference
with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()
# WavVAE decode
x[:, :vae_latent.size(1)] = vae_latent
wav_pred = self.wavvae.decode(x)[0,0].to(torch.float32)
''' Post-processing '''
# Trim prompt wav
wav_pred = wav_pred[vae_latent.size(1)*self.vae_stride*self.hop_size:].cpu().numpy()
# Norm generated wav to prompt wav's level
meter = pyln.Meter(self.sr) # create BS.1770 meter
loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt)
if np.abs(wav_pred).max() >= 1:
wav_pred = wav_pred / np.abs(wav_pred).max() * 0.95
# Apply hamming window
wav_pred_.append(wav_pred)
wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
return to_wav_bytes(wav_pred, self.sr)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input_wav', type=str)
parser.add_argument('--input_text', type=str)
parser.add_argument('--output_dir', type=str)
parser.add_argument('--time_step', type=int, default=32, help='Inference steps of Diffusion Transformer')
parser.add_argument('--p_w', type=float, default=1.6, help='Intelligibility Weight')
parser.add_argument('--t_w', type=float, default=2.5, help='Similarity Weight')
args = parser.parse_args()
wav_path, input_text, out_path, time_step, p_w, t_w = args.input_wav, args.input_text, args.output_dir, args.time_step, args.p_w, args.t_w
infer_ins = MegaTTS3DiTInfer()
with open(wav_path, 'rb') as file:
file_content = file.read()
print(f"| Start processing {wav_path}+{input_text}")
resource_context = infer_ins.preprocess(file_content, latent_file=wav_path.replace('.wav', '.npy'))
wav_bytes = infer_ins.forward(resource_context, input_text, time_step=time_step, p_w=p_w, t_w=t_w)
print(f"| Saving results to {out_path}/[P]{input_text[:20]}.wav")
os.makedirs(out_path, exist_ok=True)
save_wav(wav_bytes, f'{out_path}/[P]{input_text[:20]}.wav')
# MIT License
# Copyright (c) 2022 OpenAI
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2022] [OpenAI]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text # available at https://github.com/openai/whisper/blob/v20240930/LICENSE.
# This modified file is released under the same license.
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.functional import scaled_dot_product_attention
SDPA_AVAILABLE = True
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
@contextmanager
def disable_sdpa():
prev_state = MultiHeadAttention.use_sdpa
try:
MultiHeadAttention.use_sdpa = False
yield
finally:
MultiHeadAttention.use_sdpa = prev_state
class MultiHeadAttention(nn.Module):
use_sdpa = True
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
casual: Optional[bool] = None
):
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache[self.key]
v = kv_cache[self.value]
wv = self.qkv_attention(q, k, v, mask, casual)
return self.out(wv)
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, casual: Optional[bool] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
a = scaled_dot_product_attention(
q, k, v, is_causal=casual and n_ctx > 1, attn_mask=mask[:, None, None, :] if mask is not None else None
)
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
return out
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
casual: Optional[bool] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, casual=casual)
if self.cross_attn:
# TODO: Cross attention mask
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, casual=False)
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor, attn_mask: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding[:x.size(1)]).to(x.dtype)
for block in self.blocks:
x = block(x, mask=attn_mask, casual=False)
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
self.out_proj = nn.Linear(n_state, n_vocab)
def forward(self, x: Tensor, attn_mask: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=attn_mask, kv_cache=kv_cache, casual=True)
x = self.ln(x)
# logits = (
# x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
# ).float()
logits = self.out_proj(x)
return logits
class Whisper(nn.Module):
def __init__(self):
super().__init__()
self.n_vocab = 6800
self.n_text_layer = 6
self.n_text_head = 8
self.n_text_ctx = 2048
self.encoder = AudioEncoder(
n_mels=80, n_ctx=3000, n_state=512, n_head=8, n_layer=6,
)
self.decoder = TextDecoder(
n_vocab=6800, n_ctx=2048, n_state=512, n_head=8, n_layer=6,
)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel, None)
def logits(self, tokens, audio_features, kv_cache=None):
return self.decoder(tokens, None, audio_features, kv_cache=kv_cache)
def forward(
self, mel, mel_len, token, token_len
) -> Dict[str, torch.Tensor]:
attn_mask_enc = self.sequence_mask(mel_len//2, device=mel.device) > 0
attn_mask_dec = self.sequence_mask(token_len, device=mel.device) > 0
return self.decoder(token, attn_mask_dec, self.encoder(mel, attn_mask_enc))
@property
def device(self):
return next(self.parameters()).device
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
def sequence_mask(self, seq_lens, max_len=None, device='cpu'):
b = seq_lens.shape[0]
if max_len is None:
max_len = seq_lens.max()
mask = torch.arange(max_len).unsqueeze(0).to(device) # [1, t]
mask = mask < (seq_lens.unsqueeze(1)) # [1, t] + [b, 1] = [b, t]
mask = mask.float()
return mask
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from copy import deepcopy
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear
from tqdm import tqdm
from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
FS_ENCODERS = {
'rel_fft': lambda hp, dict_size: RelTransformerEncoder(
dict_size, hp['hidden_size'], hp['hidden_size'],
hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']),
}
def fill_with_neg_inf2(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(-1e8).type_as(t)
def expand_states(h, mel2token):
h = F.pad(h, [0, 0, 1, 0])
mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]])
h = torch.gather(h, 1, mel2token_) # [B, T, H]
return h
class CodePredictor(nn.Module):
def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
super().__init__()
self.hparams = deepcopy(hparams)
self.hparams['hidden_size'] = hidden_size
self.hidden_size = hidden_size
char_dict_size = hparams.get('char_dict_size', 4000)
if not hparams.get('lm_use_enc'):
self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
if hparams.get('mega_use_char', True):
self.char_encoder = nn.Embedding(char_dict_size,
self.hidden_size, padding_idx=0)
else:
self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
if hparams.get('mega_use_char', True):
self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
if hparams['use_ph_pos_embed']:
self.ph_pos_embed = PosEmb(self.hidden_size)
self.char_empty_embed = nn.Embedding(1, self.hidden_size)
if hparams.get('use_bert_input'):
self.bert_input_proj = nn.Linear(768, self.hidden_size)
self.ling_label_embed_layers = nn.ModuleDict()
for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)
self.dec_hidden_size = dec_hidden_size
self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
self.use_pos_embed = hparams.get('use_pos_embed', False)
if self.use_pos_embed:
self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
self.use_post_ln = hparams.get('use_post_ln', False)
self.layers = None
if not self.use_post_ln:
self.layer_norm = LayerNorm(dec_hidden_size)
self.code_size = code_size
self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)
def forward_ling_encoder(
self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
ph_tokens = txt_tokens
hparams = self.hparams
ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1]
x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)
# enc_ph
if not hparams.get('lm_use_enc'):
x_ph = self.encoder(ph_tokens)
x_ph = x_ph + sum(
[self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
if len(hparams['ling_labels']) > 0 else 0
x_ph = x_ph + x_spk
else:
# enc_ph
ph_enc_oembed = sum(
[self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \
if len(hparams['ling_labels']) > 0 else 0
ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed(
torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device))
ph_enc_oembed = ph_enc_oembed + x_spk
ph_enc_oembed = ph_enc_oembed * ph_nonpadding
x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)
# enc_char
if char_tokens is not None and ph2char is not None:
char_nonpadding = (char_tokens > 0).float()[:, :, None]
x_char = self.char_encoder(char_tokens)
empty_char = (ph2char > 100000).long()
ph2char = ph2char * (1 - empty_char)
x_char_phlevel = \
expand_states(x_char * char_nonpadding, ph2char) \
* (1 - empty_char)[..., None] + \
self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
else:
x_char_phlevel = 0
# x_ling
x_ling = x_ph + x_char_phlevel
x_ling = x_ling * ph_nonpadding
x_ling = self.enc_proj(x_ling)
return x_ling
def sample_one_step(self, vq_pred):
hparams = self.hparams
if hparams.get('infer_top_k'):
top_k = hparams.get('infer_top_k')
temperature = hparams.get('infer_temperature', 1)
vq_pred = vq_pred[:, -1] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1)))
vq_pred[vq_pred < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(vq_pred, dim=-1)
# sample from the distribution
vq_pred = torch.multinomial(probs, num_samples=1)
else:
vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1)
return vq_pred
def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
# add spk embed
style_embed = 0
if self.hparams['use_spk_embed']:
style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :]
if self.hparams['use_spk_id']:
style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :]
if self.hparams['use_spk_enc']:
style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :]
return style_embed
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, '_future_mask')
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1)
return self._future_mask[:dim, :dim]
class ARDurPredictor(CodePredictor):
def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True,
op_version=1):
super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
self.use_rot_embed = use_rot_embed
bias = hparams.get('lm_bias', True)
if self.use_rot_embed:
self.layers = nn.ModuleList([])
self.layers.extend([
RotTransformerDecoderLayer(
dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
post_ln=self.use_post_ln, op_version=op_version, bias=bias)
for _ in range(lm_num_layers)
])
if hparams['dur_model_type'] == 'ar_mse':
self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
else:
self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)
def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
prompt_length=None, cache_size=20, streaming=False):
x = self.code_emb(prev_code)
if x_ling is None:
x_ling = self.forward_ling_encoder(
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
x_ling = x_ling.flatten(0, 1)
txt_tokens = txt_tokens.flatten(0, 1)
x_ling = x_ling[txt_tokens > 0][None]
# run decoder
self_attn_padding_mask = None
if self.use_pos_embed:
positions = self.embed_positions(
prev_code,
incremental_state=incremental_state
)
if incremental_state is not None:
x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]]
if spk_pos_ids_flat is not None:
spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]]
x = x[:, -1:]
if self.use_pos_embed:
positions = positions[:, -1:]
if streaming:
# Shift Pos: query pos is min(cache_size, idx)
spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device),
spk_pos_ids_flat)
# # B x T x C -> T x B x C
if self.use_pos_embed:
x = x + positions
x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
x = x + x_ling
x = x.transpose(0, 1)
for idx, layer in enumerate(self.layers):
if incremental_state is None:
self_attn_mask = self.buffered_future_mask(x)
if attn_mask is not None:
self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
self_attn_mask = self_attn_mask.clamp_min(-1e8)
else:
self_attn_mask = None
x, attn_weights = layer(
x,
incremental_state=incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
spk_pos_ids_flat=spk_pos_ids_flat
)
if streaming and incremental_state != {}:
for k, v in incremental_state.items():
if 'attn_state' in k:
prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value']
cur_length = prev_key.shape[2]
if cur_length - prompt_length > cache_size:
prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2)
prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]),
dim=2)
incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value
if not self.use_post_ln:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
x = self.project_out_dim(x)
return x
def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
spk_id=None, spk_embed=None, mels_timbre=None,
incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
if incremental_state is None:
incremental_state = {}
x_ling = self.forward_ling_encoder(
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
spk_id, spk_embed, mels_timbre)
x_ling = x_ling.flatten(0, 1)
txt_tokens_ori = txt_tokens
txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
x_ling = x_ling[txt_tokens > 0][None]
txt_tokens = txt_tokens[txt_tokens > 0][None]
decoded = torch.zeros_like(txt_tokens)
decoded = F.pad(decoded, [1, 0], value=self.code_size + 1)
if incremental_state != {}:
if first_decoder_inp is None:
assert ctx_vqcodes is not None
decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
ctx_vqcodes = None
else:
decoded[:, :1] = first_decoder_inp
probs = []
for step in range(decoded.shape[1] - 1):
vq_pred = self(txt_tokens, None, None, None, None,
decoded[:, :step + 1], None, None, None,
incremental_state=incremental_state, x_ling=x_ling,
spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
probs.append(vq_pred.cpu())
if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
if self.hparams['dur_model_type'] == 'ar_mse':
d = vq_pred[:, -1, 0]
if dur_disturb > 0 and step >= 1:
if random.random() > 0.5:
d = d * (1 + random.random() * dur_disturb)
else:
d = d / (1 + random.random() * dur_disturb)
d = torch.clamp_max(d, self.code_size - 1)
vq_pred = torch.round(d).long()
else:
vq_pred = self.sample_one_step(vq_pred)
decoded[:, step + 1] = torch.clamp_min(vq_pred, 1)
if step == 0:
decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min)
else:
decoded[:, step + 1] = ctx_vqcodes[:, step]
decoded = decoded[:, 1:]
decoded_2d = torch.zeros_like(txt_tokens_ori)
decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded
if return_state:
return decoded_2d, incremental_state
if return_probs:
return decoded_2d, torch.cat(probs, 1)
return decoded_2d
def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
spk_id=None, spk_embed=None, mels_timbre=None,
incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
**kwargs):
if incremental_state is None:
incremental_state = {}
x_ling = self.forward_ling_encoder(
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
spk_id, spk_embed, mels_timbre)
x_ling = x_ling.flatten(0, 1)
txt_tokens_ori = txt_tokens
txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1)
x_ling = x_ling[txt_tokens > 0][None]
txt_tokens = txt_tokens[txt_tokens > 0][None]
vq_decoded = torch.zeros_like(txt_tokens)
vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1)
if incremental_state != {}:
assert ctx_vqcodes is not None
vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
ctx_vqcodes = None
prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2]
for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
vq_pred = self(txt_tokens, None, None, None, None,
vq_decoded[:, :step + 1], None, None, None,
incremental_state=incremental_state, x_ling=x_ling,
spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]:
if self.hparams['dur_model_type'] == 'ar_mse':
vq_pred = torch.round(vq_pred[:, -1, 0]).long()
else:
vq_pred = self.sample_one_step(vq_pred)
vq_decoded[:, step + 1] = vq_pred
else:
vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
vq_decoded = vq_decoded[:, 1:]
vq_decoded_2d = torch.zeros_like(txt_tokens_ori)
vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded
if return_state:
return vq_decoded_2d, incremental_state
return vq_decoded_2d
\ No newline at end of file
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
class LayerNorm(torch.nn.LayerNorm):
"""Layer normalization module.
:param int nout: output dim size
:param int dim: dimension to be normalized
"""
def __init__(self, nout, dim=-1, eps=1e-5):
"""Construct an LayerNorm object."""
super(LayerNorm, self).__init__(nout, eps=eps)
self.dim = dim
def forward(self, x):
"""Apply layer normalization.
:param torch.Tensor x: input tensor
:return: layer normalized tensor
:rtype torch.Tensor
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
class Reshape(nn.Module):
def __init__(self, *args):
super(Reshape, self).__init__()
self.shape = args
def forward(self, x):
return x.view(self.shape)
class Permute(nn.Module):
def __init__(self, *args):
super(Permute, self).__init__()
self.args = args
def forward(self, x):
return x.permute(self.args)
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
if padding_idx is not None:
nn.init.constant_(m.weight[padding_idx], 0)
return m
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
import torch.nn.functional as F
class LengthRegulator(torch.nn.Module):
def __init__(self, pad_value=0.0):
super(LengthRegulator, self).__init__()
self.pad_value = pad_value
def forward(self, dur, dur_padding=None, alpha=1.0):
"""
Example (no batch dim version):
1. dur = [2,2,3]
2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
3. token_mask = [[1,1,0,0,0,0,0],
[0,0,1,1,0,0,0],
[0,0,0,0,1,1,1]]
4. token_idx * token_mask = [[1,1,0,0,0,0,0],
[0,0,2,2,0,0,0],
[0,0,0,0,3,3,3]]
5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
:param dur: Batch of durations of each frame (B, T_txt)
:param dur_padding: Batch of padding of each frame (B, T_txt)
:param alpha: duration rescale coefficient
:return:
mel2ph (B, T_speech)
assert alpha > 0
"""
dur = torch.round(dur.float() * alpha).long()
if dur_padding is not None:
dur = dur * (1 - dur_padding.long())
token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
dur_cumsum = torch.cumsum(dur, 1)
dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
mel2token = (token_idx * token_mask.long()).sum(1)
return mel2token
class PosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim) * -emb)
self.emb = emb # TODO
def forward(self, x):
emb = x[:, :, None] * self.emb[None, None, :].to(x.device)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from torch import nn
from torch.nn import functional as F
from tts.modules.ar_dur.commons.layers import Embedding
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
class Encoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
window_size=None, block_length=None, pre_ln=False, **kwargs):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.block_length = block_length
self.pre_ln = pre_ln
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
p_dropout=p_dropout, block_length=block_length))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
self.norm_layers_2.append(LayerNorm(hidden_channels))
if pre_ln:
self.last_ln = LayerNorm(hidden_channels)
def forward(self, x, x_mask, attn_mask=1):
if isinstance(attn_mask, torch.Tensor):
attn_mask = attn_mask[:, None]
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
for i in range(self.n_layers):
x = x * x_mask
x_ = x
if self.pre_ln:
x = self.norm_layers_1[i](x)
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = x_ + y
if not self.pre_ln:
x = self.norm_layers_1[i](x)
x_ = x
if self.pre_ln:
x = self.norm_layers_2[i](x)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = x_ + y
if not self.pre_ln:
x = self.norm_layers_2[i](x)
if self.pre_ln:
x = self.last_ln(x)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
block_length=None, proximal_bias=False, proximal_init=False):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.p_dropout = p_dropout
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels ** -0.5
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
if proximal_init:
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
rel_logits = self._relative_position_to_absolute_position(rel_logits)
scores_local = rel_logits / math.sqrt(self.k_channels)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
scores = scores * block_mask + -1e4 * (1 - block_mask)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, -1])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
return x * x_mask
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-4):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
n_dims = len(x.shape)
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
shape = [1, -1] + [1] * (n_dims - 2)
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
return x
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(
nn.ReLU(),
nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class RelTransformerEncoder(nn.Module):
def __init__(self,
n_vocab,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout=0.0,
window_size=4,
block_length=None,
in_channels=None,
prenet=True,
pre_ln=True,
):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.block_length = block_length
self.prenet = prenet
if n_vocab > 0:
self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
if prenet:
if in_channels is None:
in_channels = hidden_channels
self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
kernel_size=5, n_layers=3, p_dropout=0)
if in_channels is not None and in_channels != hidden_channels:
self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
self.encoder = Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
window_size=window_size,
block_length=block_length,
pre_ln=pre_ln,
)
def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
if self.n_vocab > 0:
x_lengths = (x > 0).long().sum(-1)
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
else:
x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
x = x + other_embeds
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
if self.prenet:
x = self.pre(x, x_mask)
self.prenet_out = x.transpose(1, 2)
if hasattr(self, 'encoder_inp_proj'):
x = self.encoder_inp_proj(x) * x_mask
x = self.encoder(x, x_mask, attn_mask)
return x.transpose(1, 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment