"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "f33faf84a7d24343079626935b7219ae96e2c3e5"
Commit ecd710b6 authored by wangzhengtao's avatar wangzhengtao
Browse files

init push

parents
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.env_*
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
data
logs
*.zip
conf
.DS_Store
.ruff_cache
.log
*.jsonl
*.parquet
*.progress
# Vscode
.vscode
*.safetensors
*.model
*.pt
*.pth
test_audios/output
[submodule "kimia_infer/models/tokenizer/glm4"]
path = kimia_infer/models/tokenizer/glm4
url = https://github.com/THUDM/GLM-4-Voice.git
FROM nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04
WORKDIR /app
COPY ./requirements.txt /app/
RUN apt-get update && apt-get install -y \
python3.10 \
python3.10-dev \
curl \
sox \
openssh-server \
ffmpeg \
libgl1-mesa-glx \
git \
ninja-build \
&& rm -rf /var/lib/apt/lists/*
# Install pip
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py \
&& python3.10 get-pip.py \
&& rm get-pip.py
RUN pip install -r requirements.txt
RUN pip install flash-attn --no-build-isolation
# alias python3 as python
RUN ln -s /usr/bin/python3 /usr/bin/python
CMD ["/bin/bash"]
<p align="center">
<img src="assets/kimia_logo.png" width="400"/>
</p>
<p align="center">
Kimi-Audio-7B-Instruct <a href="https://huggingface.co/moonshotai/Kimi-Audio-7B-Instruct">🤗</a>&nbsp; | 📑 <a href="assets/kimia_report.pdf">Paper</a> &nbsp;&nbsp;
</p>
We present Kimi-Audio, an open-source audio foundation model excelling in **audio understanding, generation, and conversation**. This repository contains the official implementation, models, and evaluation toolkit for Kimi-Audio.
## 🔥🔥🔥 News!!
* April 25, 2025: 👋 We release the inference code and model weights of [Kimi-Audio-7B-Instruct](https://huggingface.co/moonshotai/Kimi-Audio-7B-Instruct).
* April 25, 2025: 👋 We release the audio evaluation toolkit [Kimi-Audio-Evalkit](https://github.com/MoonshotAI/Kimi-Audio-Evalkit). You can easily reproduce **our results and the baselines** with this toolkit!
* April 25, 2025: 👋 We release the technical report of [Kimi-Audio](assets/kimia_report.pdf).
## Table of Contents
- [Introduction](#introduction)
- [Architecture Overview](#architecture-overview)
- [Quick Start](#quick-start)
- [Evaluation](#evaluation)
- [Speech Recognition](#automatic-speech-recognition-asr)
- [Audio Understanding](#audio-understanding)
- [Audio-to-Text Chat](#audio-to-text-chat)
- [Speech Conversation](#speech-conversation)
- [Evaluation Toolkit](#evaluation-toolkit)
- [License](#license)
- [Acknowledgements](#acknowledgements)
- [Citation](#citation)
- [Contact Us](#contact-us)
## Introduction
Kimi-Audio is designed as a universal audio foundation model capable of handling a wide variety of audio processing tasks within a single unified framework. Key features include:
* **Universal Capabilities:** Handles diverse tasks like speech recognition (ASR), audio question answering (AQA), audio captioning (AAC), speech emotion recognition (SER), sound event/scene classification (SEC/ASC), and end-to-end speech conversation.
* **State-of-the-Art Performance:** Achieves SOTA results on numerous audio benchmarks (see [Evaluation](#evaluation) and the [Technical Report](assets/kimia_report.pdf)).
* **Large-Scale Pre-training:** Pre-trained on over 13 million hours of diverse audio data (speech, music, sounds) and text data, enabling robust audio reasoning and language understanding.
* **Novel Architecture:** Employs a hybrid audio input (continuous acoustic + discrete semantic tokens) and an LLM core with parallel heads for text and audio token generation.
* **Efficient Inference:** Features a chunk-wise streaming detokenizer based on flow matching for low-latency audio generation.
* **Open-Source:** We release the code, model checkpoints, and a comprehensive evaluation toolkit to foster community research and development.
## Architecture Overview
<p align="center">
<img src="assets/kimia_framework.png" width="70%"/>
</p>
Kimi-Audio consists of three main components:
1. **Audio Tokenizer:** Converts input audio into:
* Discrete semantic tokens (12.5Hz) using vector quantization.
* Continuous acoustic features derived from a Whisper encoder (downsampled to 12.5Hz).
2. **Audio LLM:** A transformer-based model (initialized from a pre-trained text LLM like Qwen 2.5 7B) with shared layers processing multimodal inputs, followed by parallel heads for autoregressively generating text tokens and discrete audio semantic tokens.
3. **Audio Detokenizer:** Converts the predicted discrete semantic audio tokens back into high-fidelity waveforms using a flow-matching model and a vocoder (BigVGAN), supporting chunk-wise streaming with a look-ahead mechanism for low latency.
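To make these rates concrete, here is a small back-of-the-envelope calculation (plain Python; it only assumes the 12.5 Hz token rate and 24 kHz waveform output quoted above):
```python
# Rough rate arithmetic for Kimi-Audio's hybrid audio representation.
TOKEN_RATE_HZ = 12.5            # semantic tokens / acoustic features per second
OUTPUT_SAMPLE_RATE_HZ = 24_000  # waveform sample rate produced by the detokenizer

audio_seconds = 10.0
semantic_tokens = int(TOKEN_RATE_HZ * audio_seconds)            # 125 tokens for 10 s of audio
samples_per_token = int(OUTPUT_SAMPLE_RATE_HZ / TOKEN_RATE_HZ)  # 1920 output samples per token

print(f"{audio_seconds:.0f}s of audio -> {semantic_tokens} semantic tokens, "
      f"{samples_per_token} waveform samples per token")
```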
## Quick Start
This example demonstrates basic usage for generating text from audio (ASR) and generating both text and speech in a conversational turn.
```python
import soundfile as sf
from kimia_infer.api.kimia import KimiAudio
# --- 1. Load Model ---
model_path = "moonshotai/Kimi-Audio-7B-Instruct"
model = KimiAudio(model_path=model_path, load_detokenizer=True)
# --- 2. Define Sampling Parameters ---
sampling_params = {
"audio_temperature": 0.8,
"audio_top_k": 10,
"text_temperature": 0.0,
"text_top_k": 5,
"audio_repetition_penalty": 1.0,
"audio_repetition_window_size": 64,
"text_repetition_penalty": 1.0,
"text_repetition_window_size": 16,
}
# --- 3. Example 1: Audio-to-Text (ASR) ---
messages_asr = [
# You can provide context or instructions as text
{"role": "user", "message_type": "text", "content": "Please transcribe the following audio:"},
# Provide the audio file path
{"role": "user", "message_type": "audio", "content": "test_audios/asr_example.wav"}
]
# Generate only text output
_, text_output = model.generate(messages_asr, **sampling_params, output_type="text")
print(">>> ASR Output Text: ", text_output) # Expected output: "这并不是告别,这是一个篇章的结束,也是新篇章的开始。"
# --- 4. Example 2: Audio-to-Audio/Text Conversation ---
messages_conversation = [
# Start conversation with an audio query
{"role": "user", "message_type": "audio", "content": "test_audios/qa_example.wav"}
]
# Generate both audio and text output
wav_output, text_output = model.generate(messages_conversation, **sampling_params, output_type="both")
# Save the generated audio
output_audio_path = "output_audio.wav"
sf.write(output_audio_path, wav_output.detach().cpu().view(-1).numpy(), 24000) # Assuming 24kHz output
print(f">>> Conversational Output Audio saved to: {output_audio_path}")
print(">>> Conversational Output Text: ", text_output) # Expected output: "A."
print("Kimi-Audio inference examples complete.")
```
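Note: if you only need text output (e.g. ASR), you can pass `load_detokenizer=False` when constructing `KimiAudio` to skip loading the flow-matching detokenizer and vocoder; in that case only `output_type="text"` is supported and the first return value of `generate` is `None`.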
## Evaluation
Kimi-Audio achieves state-of-the-art (SOTA) performance across a wide range of audio benchmarks.
Below is the overall performance:
<p align="center">
<img src="assets/kimia_radar_chart.png" width="70%"/>
</p>
Here are the results on individual benchmarks; you can easily reproduce **our results and the baselines** with our [Kimi-Audio-Evalkit](https://github.com/MoonshotAI/Kimi-Audio-Evalkit) (also see [**Evaluation Toolkit**](#evaluation-toolkit)):
### Automatic Speech Recognition (ASR)
<table>
<thead>
<tr>
<th>Datasets</th>
<th>Model</th>
<th>Performance (WER&darr;)</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="5"><strong>LibriSpeech</strong><br>test-clean | test-other</td>
<td>Qwen2-Audio-base</td>
<td>1.74 | 4.04</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>3.02 | 6.04</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>3.19 | 10.67</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>2.37 | 4.21</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>1.28</strong> | <strong>2.42</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>Fleurs</strong><br>zh | en</td>
<td>Qwen2-Audio-base</td>
<td>3.63 | 5.20</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>4.15 | 8.07</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>4.26 | 8.56</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>2.92 | <strong>4.17</strong></td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>2.69</strong> | 4.44</td>
</tr>
<tr>
<td rowspan="5"><strong>AISHELL-1</strong></td>
<td>Qwen2-Audio-base</td>
<td>1.52</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>1.93</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>2.14</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>1.13</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>0.60</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>AISHELL-2</strong> ios</td>
<td>Qwen2-Audio-base</td>
<td>3.08</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>3.87</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>3.89</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td><strong>2.56</strong></td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>2.56</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>WenetSpeech</strong><br>test-meeting | test-net</td>
<td>Qwen2-Audio-base</td>
<td>8.40 | 7.64</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>13.28 | 10.13</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>10.83 | 9.47</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>7.71 | 6.04</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>6.28</strong> | <strong>5.37</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>Kimi-ASR Internal Testset</strong><br>subset1 | subset2</td>
<td>Qwen2-Audio-base</td>
<td>2.31 | 3.24</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>3.41 | 5.60</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>2.82 | 4.74</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>1.53 | 2.68</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>1.42</strong> | <strong>2.44</strong></td>
</tr>
</tbody>
</table>
### Audio Understanding
<table>
<thead>
<tr>
<th>Datasets</th>
<th>Model</th>
<th>Performance&uarr;</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="6"><strong>MMAU</strong><br>music | sound | speech</td>
<td>Qwen2-Audio-base</td>
<td>58.98 | 69.07 | 52.55</td>
</tr>
<tr>
<td>Baichuan-chat</td>
<td>49.10 | 59.46 | 42.47</td>
</tr>
<tr>
<td>GLM-4-Voice</td>
<td>38.92 | 43.54 | 32.43</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>49.40 | 53.75 | 47.75</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td><strong>62.16</strong> | 67.57 | 53.92</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td>61.68 | <strong>73.27</strong> | <strong>60.66</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>ClothoAQA</strong><br>test | dev</td>
<td>Qwen2-Audio-base</td>
<td>71.73 | 72.63</td>
</tr>
<tr>
<td>Baichuan-chat</td>
<td>48.02 | 48.16</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>45.84 | 44.98</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td><strong>72.86</strong> | 73.12</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td>71.24 | <strong>73.18</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>VocalSound</strong></td>
<td>Qwen2-Audio-base</td>
<td>93.82</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>58.17</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>28.58</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>93.73</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>94.85</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>Nonspeech7k</strong></td>
<td>Qwen2-Audio-base</td>
<td>87.17</td>
</tr>
<tr>
<td>Baichuan-chat</td>
<td>59.03</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>21.38</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>69.89</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>93.93</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>MELD</strong></td>
<td>Qwen2-Audio-base</td>
<td>51.23</td>
</tr>
<tr>
<td>Baichuan-chat</td>
<td>23.59</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>33.54</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>49.83</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>59.13</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>TUT2017</strong></td>
<td>Qwen2-Audio-base</td>
<td>33.83</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>27.9</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>7.41</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>43.27</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>65.25</strong></td>
</tr>
<tr>
<td rowspan="5"><strong>CochlScene</strong><br>test | dev</td>
<td>Qwen2-Audio-base</td>
<td>52.69 | 50.96</td>
</tr>
<tr>
<td>Baichuan-base</td>
<td>34.93 | 34.56</td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>10.06 | 10.42</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>63.82 | 63.82</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>79.84</strong> | <strong>80.99</strong></td>
</tr>
</tbody>
</table>
### Audio-to-Text Chat
<table>
<thead>
<tr>
<th>Datasets</th>
<th>Model</th>
<th>Performance↑</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="6"><b>OpenAudioBench</b><br>AlpacaEval | Llama Questions |<br>Reasoning QA | TriviaQA | Web Questions</td>
<td>Qwen2-Audio-chat</td>
<td>57.19 | 69.67 | 42.77 | 40.30 | 45.20</td>
</tr>
<tr>
<td>Baichuan-chat</td>
<td>59.65 | 74.33 | 46.73 | 55.40 | 58.70</td>
</tr>
<tr>
<td>GLM-4-Voice</td>
<td>57.89 | 76.00 | 47.43 | 51.80 | 55.40</td>
</tr>
<tr>
<td>StepAudio-chat</td>
<td>56.53 | 72.33 | 60.00 | 56.80 | <b>73.00</b></td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>72.76 | 75.33 | <b>63.76</b> | 57.06 | 62.80</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><b>75.73</b> | <b>79.33</b> | 58.02 | <b>62.10</b> | 70.20</td>
</tr>
<tr>
<td rowspan="6"><b>VoiceBench</b><br>AlpacaEval | CommonEval |<br>SD-QA | MMSU</td>
<td>Qwen2-Audio-chat</td>
<td>3.69 | 3.40 | 35.35 | 35.43</td>
</tr>
<tr>
<td>Baichuan-chat</td>
<td>4.00 | 3.39 | 49.64 | 48.80</td>
</tr>
<tr>
<td>GLM-4-Voice</td>
<td>4.06 | 3.48 | 43.31 | 40.11</td>
</tr>
<tr>
<td>StepAudio-chat</td>
<td>3.99 | 2.99 | 46.84 | 28.72</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>4.33 | 3.84 | 57.41 | 56.38</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><b>4.46</b> | <b>3.97</b> | <b>63.12</b> | <b>62.17</b></td>
</tr>
<tr>
<td rowspan="6"><b>VoiceBench</b><br>OpenBookQA | IFEval |<br>AdvBench | Avg</td>
<td>Qwen2-Audio-chat</td>
<td>49.01 | 22.57 | 98.85 | 54.72</td>
</tr>
<tr>
<td>Baichuan-chat</td>
<td>63.30 | 41.32 | 86.73 | 62.51</td>
</tr>
<tr>
<td>GLM-4-Voice</td>
<td>52.97 | 24.91 | 88.08 | 57.17</td>
</tr>
<tr>
<td>StepAudio-chat</td>
<td>31.87 | 29.19 | 65.77 | 48.86</td>
</tr>
<tr>
<td>Qwen2.5-Omni</td>
<td>79.12 | 53.88 | 99.62 | 72.83</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><b>83.52</b> | <b>61.10</b> | <b>100.00</b> | <b>76.93</b></td>
</tr>
</tbody>
</table>
### Speech Conversation
<table>
<caption>Performance of Kimi-Audio and baseline models on speech conversation.</caption>
<thead>
<tr>
<th rowspan="2">Model</th>
<th colspan="6">Ability</th>
</tr>
<tr>
<th>Speed Control</th>
<th>Accent Control</th>
<th>Emotion Control</th>
<th>Empathy</th>
<th>Style Control</th>
<th>Avg</th>
</tr>
</thead>
<tbody>
<tr>
<td>GPT-4o</td>
<td>4.21</td>
<td><strong>3.65</strong></td>
<td>4.05</td>
<td><strong>3.87</strong></td>
<td><strong>4.54</strong></td>
<td><strong>4.06</strong></td>
</tr>
<tr>
<td>Step-Audio-chat</td>
<td>3.25</td>
<td>2.87</td>
<td>3.33</td>
<td>3.05</td>
<td>4.14</td>
<td>3.33</td>
</tr>
<tr>
<td>GLM-4-Voice</td>
<td>3.83</td>
<td>3.51</td>
<td>3.77</td>
<td>3.07</td>
<td>4.04</td>
<td>3.65</td>
</tr>
<tr>
<td>GPT-4o-mini</td>
<td>3.15</td>
<td>2.71</td>
<td>4.24</td>
<td>3.16</td>
<td>4.01</td>
<td>3.45</td>
</tr>
<tr>
<td>Kimi-Audio</td>
<td><strong>4.30</strong></td>
<td>3.45</td>
<td><strong>4.27</strong></td>
<td>3.39</td>
<td>4.09</td>
<td>3.90</td>
</tr>
</tbody>
</table>
## Evaluation Toolkit
Evaluating and comparing audio foundation models is challenging due to inconsistent metrics, varying inference configurations, and a lack of standardized generation evaluation. To address this, we developed and open-sourced an **Evaluation Toolkit**.
Key features:
* Integrates Kimi-Audio and other recent audio LLMs.
* Implements standardized metric calculation and integrates LLMs for intelligent judging (e.g., for AQA).
* Provides a unified platform for side-by-side comparisons with shareable inference 'recipes' for reproducibility.
* Includes a benchmark for evaluating speech conversation abilities (control, empathy, style).
We encourage the community to use and contribute to this toolkit to foster more reliable and comparable benchmarking. Find it here: [Kimi-Audio-Evalkit](https://github.com/MoonshotAI/Kimi-Audio-Evalkit).
## License
The model is based on and modified from [Qwen 2.5-7B](https://github.com/QwenLM/Qwen2.5). Code derived from Qwen2.5-7B is licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0). Other parts of the code are licensed under the [MIT License](https://opensource.org/licenses/MIT).
## Acknowledgements
We would like to thank the following projects and individuals for their contributions to the development of Kimi-Audio:
* [Whisper](https://github.com/openai/whisper)
* [Transformers](https://github.com/huggingface/transformers)
* [BigVGAN](https://github.com/NVIDIA/BigVGAN)
* [GLM-4-Voice](https://github.com/THUDM/GLM-4-Voice)
Thank you to all the open-source projects for their contributions to this project!
## Citation
If you find Kimi-Audio useful in your research or applications, please cite our technical report:
```bibtex
@misc{kimi_audio_2024,
title={Kimi-Audio Technical Report},
author={Kimi Team},
year={2024},
eprint={arXiv:placeholder},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
## Contact Us
For questions, issues, or collaboration inquiries, please feel free to open an issue on GitHub.
from kimia_infer.api.kimia import KimiAudio
import os
import soundfile as sf
if __name__ == "__main__":
model = KimiAudio(
model_path="moonshotai/Kimi-Audio-7B-Instruct",
load_detokenizer=True,
)
sampling_params = {
"audio_temperature": 0.8,
"audio_top_k": 10,
"text_temperature": 0.0,
"text_top_k": 5,
"audio_repetition_penalty": 1.0,
"audio_repetition_window_size": 64,
"text_repetition_penalty": 1.0,
"text_repetition_window_size": 16,
}
messages = [
{"role": "user", "message_type": "text", "content": "请将音频内容转换为文字。"},
{
"role": "user",
"message_type": "audio",
"content": "test_audios/asr_example.wav",
},
]
wav, text = model.generate(messages, **sampling_params, output_type="text")
print(">>> output text: ", text)
output_dir = "test_audios/output"
os.makedirs(output_dir, exist_ok=True)
# audio2audio
messages = [
{
"role": "user",
"message_type": "audio",
"content": "test_audios/qa_example.wav",
}
]
wav, text = model.generate(messages, **sampling_params, output_type="both")
sf.write(
os.path.join(output_dir, "output.wav"),
wav.detach().cpu().view(-1).numpy(),
24000,
)
print(">>> output text: ", text)
import os
import tqdm
import torch
from loguru import logger
from huggingface_hub import cached_assets_path
from transformers import AutoModelForCausalLM
from kimia_infer.models.detokenizer import get_audio_detokenizer
from .prompt_manager import KimiAPromptManager
from kimia_infer.utils.sampler import KimiASampler
class KimiAudio(object):
def __init__(self, model_path: str, load_detokenizer: bool = True):
logger.info(f"Loading kimi-audio main model")
self.alm = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
)
self.alm = self.alm.to(torch.cuda.current_device())
model_config = self.alm.config
self.kimia_token_offset = model_config.kimia_token_offset
self.prompt_manager = KimiAPromptManager(
model_path=model_path, kimia_token_offset=self.kimia_token_offset
)
if os.path.exists(model_path):
# local path
cache_path = model_path
else:
# model_id
cache_path = cached_assets_path(
library_name="transformers", namespace=model_path
)
if load_detokenizer:
logger.info(f"Loading detokenizer")
# Extension modules need to be compiled the first time this runs; it may take several minutes.
self.detokenizer = get_audio_detokenizer(cache_path)
else:
# in this case, generating audio (wav) output is not supported
self.detokenizer = None
self.extra_tokens = self.prompt_manager.extra_tokens
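# The audio head runs a fixed 6-token delay behind the text head: for the first
# `kimia_text_audiodelaytokens` decoding steps the audio stream is filled with blanks
# (see _generate_loop below).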
self.kimia_text_audiodelaytokens = 6
self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end]
def _generate_loop(
self,
audio_input_ids: torch.Tensor, # input audio tokens
text_input_ids: torch.Tensor = None,  # input text tokens when multimodal input is used
max_new_tokens: int = 50,
audio_top_k: int = 5,
audio_temperature: float = 0.0,
audio_repetition_penalty: float = 1.0,
audio_repetition_window_size: int = 64,
text_top_k: int = 5,
text_temperature: float = 0.0,
text_repetition_penalty: float = 1.0,
text_repetition_window_size: int = 16,
is_continuous_mask: torch.Tensor = None,
continous_feature: torch.Tensor = None,
output_type: str = "text",
):
sampler = KimiASampler(
audio_top_k=audio_top_k,
audio_temperature=audio_temperature,
audio_repetition_penalty=audio_repetition_penalty,
audio_repetition_window_size=audio_repetition_window_size,
text_top_k=text_top_k,
text_temperature=text_temperature,
text_repetition_penalty=text_repetition_penalty,
text_repetition_window_size=text_repetition_window_size,
)
text_stream_is_finished = False
previous_audio_tokens = torch.zeros(
(4096,),
dtype=torch.int,
device=torch.cuda.current_device(),
)
text_previous_tokens = torch.zeros(
(4096,),
dtype=torch.int,
device=torch.cuda.current_device(),
)
decoder_input_audio_ids = audio_input_ids.clone()
decoder_input_text_ids = text_input_ids.clone()
decoder_position_ids = (
torch.arange(
0, decoder_input_audio_ids.shape[1], device=torch.cuda.current_device()
)
.unsqueeze(0)
.long()
)
decoder_input_whisper_feature = continous_feature
decoder_is_continuous_mask = is_continuous_mask
past_key_values = None
last_position_id = decoder_input_audio_ids.shape[1] - 1
valid_text_length = 0
valid_audio_length = 0
for i in tqdm.tqdm(
range(max_new_tokens), desc="Generating tokens", disable=False
):
audio_logits, text_logits, past_key_values = self.alm.forward(
input_ids=decoder_input_audio_ids,
text_input_ids=decoder_input_text_ids,
whisper_input_feature=decoder_input_whisper_feature,
is_continuous_mask=decoder_is_continuous_mask,
position_ids=decoder_position_ids,
past_key_values=past_key_values,
return_dict=False,
)
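# Parallel heads: a single forward step yields logits for the next text token
# and the next audio semantic token at the same time (see Architecture Overview).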
# Sample text token using the sampler
next_token_text = sampler.sample_text_logits(
text_logits, recent_tokens=text_previous_tokens[:i] if i > 0 else None
)
# Sample audio token using the sampler
next_audio_token = sampler.sample_audio_logits(
audio_logits, recent_tokens=previous_audio_tokens[:i] if i > 0 else None
)
if text_stream_is_finished:
next_token_text.fill_(self.extra_tokens.kimia_text_blank)
elif next_token_text.item() == self.extra_tokens.kimia_text_eos:
text_stream_is_finished = True
else:
valid_text_length += 1
text_previous_tokens[i : i + 1] = next_token_text
if i < self.kimia_text_audiodelaytokens:
next_audio_token.fill_(self.extra_tokens.kimia_text_blank)
else:
if output_type == "text":
next_audio_token.fill_(self.extra_tokens.kimia_text_blank)
else:
valid_audio_length += 1
previous_audio_tokens[i : i + 1] = next_audio_token
audio_stream_is_finished = next_audio_token.item() in self.eod_ids
if (
output_type == "text"
and text_stream_is_finished
or output_type == "both"
and audio_stream_is_finished
):
return_text_tokens = (
text_previous_tokens[:valid_text_length]
.detach()
.cpu()
.numpy()
.tolist()
)
return_audio_tokens = (
previous_audio_tokens[
self.kimia_text_audiodelaytokens : valid_audio_length
+ self.kimia_text_audiodelaytokens
]
.detach()
.cpu()
.numpy()
.tolist()
)
return return_audio_tokens, return_text_tokens
else:
decoder_input_audio_ids = next_audio_token.unsqueeze(1)
decoder_input_text_ids = next_token_text.unsqueeze(1)
decoder_position_ids = (
torch.zeros(1, 1, device=torch.cuda.current_device())
.fill_(last_position_id + 1)
.long()
.view(1, 1)
)
last_position_id += 1
decoder_input_whisper_feature = None
decoder_is_continuous_mask = None
return_text_tokens = (
text_previous_tokens[:valid_text_length].detach().cpu().numpy().tolist()
)
return_audio_tokens = (
previous_audio_tokens[
self.kimia_text_audiodelaytokens : valid_audio_length
+ self.kimia_text_audiodelaytokens
]
.detach()
.cpu()
.numpy()
.tolist()
)
return return_audio_tokens, return_text_tokens
def generate(
self,
chats: list[dict],
output_type="text",
audio_temperature=0.0,
audio_top_k=5,
text_temperature=0.0,
text_top_k=5,
audio_repetition_penalty=1.0,
audio_repetition_window_size=64,
text_repetition_penalty=1.0,
text_repetition_window_size=16,
max_new_tokens=-1,
):
## TODO: add a check function to validate that the input history format is legal
## e.g., for ASR tasks the order must be text-instruction/audio-instruction + audio-content; as I understand it, content and instruction cannot swap positions
## an assistant turn must be preceded by a user turn, etc.; it would be best to add such a check
assert output_type in ["text", "both"]
history = self.prompt_manager.get_prompt(chats, output_type=output_type)
audio_input_ids, text_input_ids, is_continuous_mask = history.to_tensor()
audio_features = history.continuous_feature
generated_wav_tokens = []
generated_text_tokens = []
if output_type == "both":
max_new_tokens = int(12.5 * 120) - audio_input_ids.shape[1]
else:
if max_new_tokens == -1:
max_new_tokens = 7500 - audio_input_ids.shape[1]
audio_input_ids = audio_input_ids.to(torch.cuda.current_device())
text_input_ids = text_input_ids.to(torch.cuda.current_device())
is_continuous_mask = is_continuous_mask.to(torch.cuda.current_device())
audio_features = [f.to(torch.cuda.current_device()) for f in audio_features]
generated_wav_tokens, generated_text_tokens = self._generate_loop(
audio_input_ids=audio_input_ids,
text_input_ids=text_input_ids,
max_new_tokens=max_new_tokens,
audio_temperature=audio_temperature,
audio_top_k=audio_top_k,
audio_repetition_penalty=audio_repetition_penalty,
audio_repetition_window_size=audio_repetition_window_size,
text_top_k=text_top_k,
text_temperature=text_temperature,
text_repetition_penalty=text_repetition_penalty,
text_repetition_window_size=text_repetition_window_size,
is_continuous_mask=is_continuous_mask,
continous_feature=audio_features,
output_type=output_type,
)
generated_wav_tokens = [
t for t in generated_wav_tokens if t >= self.kimia_token_offset
] # filter out the illegal tokens
generated_wav_tokens = torch.tensor(generated_wav_tokens).unsqueeze(0)
generated_wav_tokens = generated_wav_tokens - self.kimia_token_offset
generated_text_tokens = [
t for t in generated_text_tokens if t < self.kimia_token_offset
]
generated_text = self.detokenize_text(generated_text_tokens)
if self.detokenizer is not None and output_type == "both":
generated_wav = self.detokenize_audio(generated_wav_tokens)
else:
generated_wav = None
return generated_wav, generated_text
def detokenize_audio(self, audio_tokens):
if self.detokenizer is None:
raise ValueError("Detokenizer is not initialized")
self.detokenizer.clear_states()
chunk_size = 30 # hard-coded right now
first_chunk_size = 30
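# 30 semantic tokens at 12.5 Hz correspond to ~2.4 s of audio per streaming chunk;
# upsample_factor=4 expands them to the mel-frame rate (12.5 Hz x 4 = 50 frames/s).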
cache_speech_collection = []
audio_tokens = audio_tokens.to(torch.cuda.current_device())
audio_tokens = audio_tokens.long()
num_audio_tokens = audio_tokens.size(1)
first_chunk_semantic_tokens = audio_tokens[:, :first_chunk_size]
gen_speech = self.detokenizer.detokenize_streaming(
first_chunk_semantic_tokens,
is_final=(num_audio_tokens <= first_chunk_size),
upsample_factor=4,
)
cache_speech_collection.append(gen_speech)
if num_audio_tokens > first_chunk_size:
res_semantic_tokens = audio_tokens[:, first_chunk_size:]
for i in range(0, res_semantic_tokens.size(1), chunk_size):
chunk_semantic_tokens = res_semantic_tokens[:, i : i + chunk_size]
gen_speech = self.detokenizer.detokenize_streaming(
chunk_semantic_tokens,
upsample_factor=4,
is_final=(i + chunk_size >= res_semantic_tokens.size(1)),
)
cache_speech_collection.append(gen_speech)
gen_speech = torch.cat(cache_speech_collection, dim=-1)
return gen_speech
def detokenize_text(self, text_tokens):
valid_text_ids = []
for x in text_tokens:
if x == self.extra_tokens.kimia_text_eos:
break
valid_text_ids.append(x)
return self.prompt_manager.text_tokenizer.decode(valid_text_ids)
from typing import List, Dict
import os
import librosa
import torch
from loguru import logger
from huggingface_hub import cached_assets_path
from transformers import AutoTokenizer
from kimia_infer.models.tokenizer.whisper_Lv3.whisper import WhisperEncoder
from kimia_infer.models.tokenizer.glm4_tokenizer import Glm4Tokenizer
from kimia_infer.utils.data import KimiAContent
from kimia_infer.utils.special_tokens import instantiate_extra_tokens
class KimiAPromptManager:
def __init__(self, model_path: str, kimia_token_offset: int):
self.audio_tokenizer = Glm4Tokenizer("THUDM/glm-4-voice-tokenizer")
self.audio_tokenizer = self.audio_tokenizer.to(torch.cuda.current_device())
if os.path.exists(model_path):
# local path
cache_path = model_path
else:
# model_id
cache_path = cached_assets_path(
library_name="transformers", namespace=model_path
)
logger.info(f"Looking for resources in {cache_path}")
logger.info(f"Loading whisper model")
self.whisper_model = WhisperEncoder(
os.path.join(cache_path, "whisper-large-v3"), mel_batch_size=20
)
self.whisper_model = self.whisper_model.to(torch.cuda.current_device())
self.whisper_model = self.whisper_model.bfloat16()
self.whisper_model.eval()
logger.info(f"Loading text tokenizer")
self.text_tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True
)
self.extra_tokens = instantiate_extra_tokens(self.text_tokenizer)
self.kimia_token_offset = kimia_token_offset
def _tokenize_text(self, text):
if text is None:
return None
token_ids = self.text_tokenizer.encode(text, bos=False, eos=False)
return token_ids
def _tokenize_audio(self, wav_path):
wav_tokens = self.audio_tokenizer.tokenize(audio_path=wav_path)
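# Shift audio token ids above the text vocabulary by kimia_token_offset so that
# text and audio tokens share a single id space; generate() filters and subtracts
# this offset again before detokenizing.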
wav_tokens = wav_tokens + self.kimia_token_offset
wav_tokens_list = wav_tokens.squeeze(0).cpu().numpy().tolist()
return wav_tokens_list
def extract_whisper_feat(self, wav: torch.Tensor | str):
if isinstance(wav, str):
wav = librosa.load(wav, sr=16000)[0]
wav_tensor = torch.tensor(wav).unsqueeze(0)[:, :]
elif isinstance(wav, torch.Tensor):
wav_tensor = wav
else:
raise ValueError(f"Invalid wav type: {type(wav)}")
assert self.whisper_model is not None
wav_tensor = wav_tensor.to(torch.cuda.current_device())
continous_feature = self.whisper_model.tokenize_waveform(wav_tensor)
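# Group every 4 consecutive Whisper feature frames into one vector (T/4 frames with
# 4x the channels), downsampling the continuous features to the 12.5 Hz rate of the
# discrete semantic tokens so the two streams align one-to-one.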
continous_feature = continous_feature.reshape(
continous_feature.shape[0],
int(continous_feature.shape[1] // 4),
continous_feature.shape[2] * 4,
)
return continous_feature
def tokenize_message(
self,
message,
tokenize_role=True,
has_ct_token=False,
has_msg_end_token=False,
extract_whisper_feature=False,
output_type: str = "text",
):
kimia_content_msg = KimiAContent()
role = message["role"]
if tokenize_role:
if role == "user":
kimia_content_msg.audio_append(self.extra_tokens.kimia_user_msg_start)
kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank)
elif role == "assistant":
kimia_content_msg.audio_append(
self.extra_tokens.kimia_assistant_msg_start
)
kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank)
else:
raise NotImplementedError(f"role: {role}")
if message["message_type"] == "text":
text = message["content"]
text_tokens = self._tokenize_text(text)
kimia_content_msg.text_extend(text_tokens)
kimia_content_msg.audio_extend(
[self.extra_tokens.kimia_text_blank] * len(text_tokens)
)
elif message["message_type"] == "audio":
audio_path = message["content"]
speech_tokens = self._tokenize_audio(audio_path)
kimia_content_msg.audio_append(self.extra_tokens.media_begin)
kimia_content_msg.audio_extend(speech_tokens, is_continuous=True)
kimia_content_msg.audio_append(self.extra_tokens.media_end)
kimia_content_msg.text_extend(
[self.extra_tokens.kimia_text_blank] * (len(speech_tokens) + 2)
)
if has_ct_token:
if output_type == "text":
kimia_content_msg.audio_append(self.extra_tokens.kimia_speech_ct_id)
else:
kimia_content_msg.audio_append(
self.extra_tokens.kimia_speech_ctd_id
)
kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank)
if extract_whisper_feature:
whisper_feature = self.extract_whisper_feat(audio_path)
kimia_content_msg.continuous_feature.append(whisper_feature)
elif message["message_type"] == None:
pass
else:
raise NotImplementedError(f"message_type: {message['message_type']}")
if has_msg_end_token:
kimia_content_msg.audio_append(self.extra_tokens.msg_end)
kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank)
assert (
kimia_content_msg.is_valid()
), f"kimia_content_msg is not valid: {kimia_content_msg}"
return kimia_content_msg
def get_prompt(
self, messages: List[Dict], output_type: str = "text"
) -> KimiAContent:
"""
messages: List[Dict]
messages[i] = {
"role": "user" | "assistant" | "system",
"content": str
}
"""
assert output_type in ["text", "both"]
msgs: List[KimiAContent] = []
tokenize_role = True
has_ct_token = False
has_msg_end_token = False
previous_role = None
for msg_idx, message in enumerate(messages):
assert message["role"] in ["user", "assistant"]
if previous_role is None:
tokenize_role = True
else:
if message["role"] == previous_role:
tokenize_role = False
else:
tokenize_role = True
if msg_idx == len(messages) - 1:
has_ct_token = True
has_msg_end_token = True
else:
if messages[msg_idx + 1]["role"] != message["role"]:
has_ct_token = True
has_msg_end_token = True
else:
has_ct_token = False
has_msg_end_token = False
previous_role = message["role"]
msg = self.tokenize_message(
message=message,
tokenize_role=tokenize_role,
has_ct_token=has_ct_token,
has_msg_end_token=has_msg_end_token,
extract_whisper_feature=True,
output_type=output_type,
)
msgs.append(msg)
assistant_start_msg = self.tokenize_message(
message={
"role": "assistant",
"message_type": None,
},
tokenize_role=True,
has_ct_token=False,
has_msg_end_token=False,
)
msgs.append(assistant_start_msg)
ret_msg = msgs[0]
for msg in msgs[1:]:
ret_msg.merge(msg)
return ret_msg
import torch
import os
from .bigvgan_wrapper import BigVGANWrapper
from .semantic_fm_prefix_streaming import StreamingSemanticFMWrapper
class PrefixStreamingFlowMatchingDetokenizer:
def __init__(
self,
vocoder: BigVGANWrapper,
fm: StreamingSemanticFMWrapper,
look_ahead_tokens: int = 0,
) -> None:
self.dtype = torch.bfloat16
print("Currently using bfloat16 for PrefixFlowMatchingDetokenizer")
self.vocoder = vocoder
self.vocoder.to_dtype(self.dtype)
self.semantic_fm = fm
# initialize mel_spec
self.max_pos_size = 4096
self.is_timbre_semantic_token = False
self.pre_mel = None
self.frame_size = 480 # how many samples in a frame
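# 480 samples at 24 kHz -> 20 ms per mel frame (50 frames per second)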
self.pre_wav = None
self.state_dict_backup = None
self.hamming_window_cache = {}
self.previous_chunk_left = None
self.look_ahead_tokens = look_ahead_tokens
self.clear_states()
@classmethod
def from_pretrained(
cls,
vocoder_config,
vocoder_ckpt,
fm_config,
fm_ckpt,
device,
look_ahead_tokens=0,
max_prompt_chunk=2,
max_kv_cache_tokens=900,
use_cfg=False,
use_cfg_rescale=True,
cfg_init=1.5,
cfg_scale=7.5,
cfg_schedule="linear",
):
bigvgan = BigVGANWrapper.from_pretrained(vocoder_config, vocoder_ckpt, device)
semantic_fm = StreamingSemanticFMWrapper.from_pretrained(
fm_config,
fm_ckpt,
device,
max_prompt_chunk=max_prompt_chunk,
max_kv_cache_tokens=max_kv_cache_tokens,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
use_cfg_rescale=use_cfg_rescale,
cfg_init=cfg_init,
cfg_schedule=cfg_schedule,
)
return cls(bigvgan, semantic_fm, look_ahead_tokens=look_ahead_tokens)
@torch.inference_mode()
def prefill(
self, timbre_speech, timbre_semantic_token, chunk_size: int, timbre_mel=None
):
"""
Arguments:
timbre_speech: torch.Tensor, shape [B, N_speech_24k]
timbre_semantic_token: torch.Tensor, shape [B, N]
chunk_size: int, chunk size for prefilling
timbre_mel: torch.Tensor, shape [B, N, 80], optional, if not None, use this mel spectrogram instead of extracting from timbre_speech
"""
if timbre_mel is None:
assert (
timbre_speech is not None
), "timbre_speech should not be None if timbre_mel is not None"
assert (
len(timbre_semantic_token.shape) == 2
and len(timbre_speech.shape) == 2
and chunk_size > 0
)
assert timbre_speech.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
mel_spec = self.vocoder.extract_mel_from_wav(
wav_data=timbre_speech.squeeze(0)
)
else:
assert (
len(timbre_mel.shape) == 3
and len(timbre_semantic_token.shape) == 2
and chunk_size > 0
)
assert timbre_mel.shape[0] == 1 and timbre_semantic_token.shape[0] == 1
mel_spec = timbre_mel.squeeze(0)
if mel_spec.shape[0] < timbre_semantic_token.shape[1]:
# pad mel_spec
mel_spec = torch.nn.functional.pad(
mel_spec, (0, 0, 0, timbre_semantic_token.shape[1] - mel_spec.shape[0])
)
elif mel_spec.shape[0] > timbre_semantic_token.shape[1]:
# truncate mel_spec
mel_spec = mel_spec[: timbre_semantic_token.shape[1], :]
# clear all states
self.semantic_fm.clear_all_states()
self.semantic_fm.prefill(
mel_spec,
timbre_semantic_token.squeeze(0),
chunk_size=chunk_size,
verbose=False,
)
self.state_dict_backup = self.semantic_fm.state_dict()
@torch.inference_mode()
def detokenize_streaming(
self,
semantic_token,
ode_step=30,
verbose=False,
ode_solver="neural_ode_euler",
is_final=False,
upsample_factor=1,
):
assert len(semantic_token.shape) == 2 and ode_step > 0
assert semantic_token.shape[0] == 1
semantic_token = semantic_token.repeat_interleave(upsample_factor, dim=1)
semantic_token = semantic_token.squeeze(0)
if self.look_ahead_tokens != 0 and self.previous_chunk_left is not None:
semantic_token_previous = self.previous_chunk_left["semantic_token"]
semantic_token = torch.cat(
[semantic_token_previous, semantic_token], dim=-1
)
x_t_chunk = (
torch.randn(semantic_token.shape[0], 80)
.to(semantic_token.device)
.to(self.dtype)
)
if self.look_ahead_tokens != 0 and self.previous_chunk_left is None:
self.previous_chunk_left = {"semantic_token": None}
speech_mel = self.semantic_fm.infer_chunk(
xt_chunk=x_t_chunk,
semantic_tokens_chunk=semantic_token,
start_position_id=self.semantic_fm.start_position_id,
ode_steps=ode_step,
verbose=verbose,
look_ahead_tokens=(
self.look_ahead_tokens * upsample_factor if not is_final else 0
),
cache=self.previous_chunk_left,
ode_solver=ode_solver,
)
chunk_size = speech_mel.shape[0]
length = speech_mel.shape[0]
self.semantic_fm.start_position_id += length
self.semantic_fm.update_incremental_state()
self.semantic_fm.reserve_kv_cache_tokens += (
self.semantic_fm.ode_wrapper.kv_cache_tokens
)
# smoothing
# Maintain a history of previously generated waveform samples.
# For the first chunk, return only the first half of the chunk and keep the remaining half in history.
# For subsequent requests, concatenate the newly generated waveform with the history, cross-fade the overlap with a Hamming window, output one chunk, and keep the remainder for the next step.
if self.pre_mel is None: # first chunk, related to TTFB
concat_mel = speech_mel
concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
if is_final:
self.clear_states()
self.state_dict_backup = None
ret_wav = concat_reconstructed_wav.float()
else:
reconstructed_wav = concat_reconstructed_wav[
:, : int(self.frame_size * chunk_size // 2)
] # return the first half chunk
self.pre_wav = concat_reconstructed_wav[
:, -int(self.frame_size * chunk_size // 2) :
] # log the last half chunk for next generation step
self.pre_mel = speech_mel[-chunk_size // 2 :, :]
ret_wav = reconstructed_wav.float()
else:
concat_mel = torch.cat([self.pre_mel, speech_mel], dim=0)
concat_reconstructed_wav = self.vocoder.decode_mel(concat_mel)
if is_final:
self.clear_states()
self.state_dict_backup = None
ret_wav = concat_reconstructed_wav.float()
else:
# fetch history
prev_speech_len = self.pre_wav.shape[1]
if concat_reconstructed_wav.shape[1] > prev_speech_len * 2:
gen_speech_len = prev_speech_len * 2
else:
gen_speech_len = concat_reconstructed_wav.shape[1] // 2
reconstructed_wav = concat_reconstructed_wav[
:, :gen_speech_len
] # return the first half chunk
if gen_speech_len not in self.hamming_window_cache:
self.hamming_window_cache[gen_speech_len] = (
torch.hamming_window(gen_speech_len)
.to(self.dtype)
.to(semantic_token.device)
.unsqueeze(0)
)
hamming_window = self.hamming_window_cache[gen_speech_len]
# apply smoothing of the first half chunk
reconstructed_wav[:, : int(gen_speech_len // 2)] = (
self.pre_wav[:, : int(gen_speech_len // 2)]
* hamming_window[:, -int(gen_speech_len // 2) :]
+ reconstructed_wav[:, : int(gen_speech_len // 2)]
* hamming_window[:, : int(gen_speech_len // 2)]
)
res_speech_len = concat_reconstructed_wav.shape[1] - gen_speech_len
res_mel_len = res_speech_len // self.frame_size
self.pre_wav = concat_reconstructed_wav[:, -res_speech_len:]
self.pre_mel = speech_mel[-res_mel_len:, :]
ret_wav = reconstructed_wav.float()
if (
not is_final
and self.semantic_fm.start_position_id + 2 * chunk_size > self.max_pos_size
):
# ran out of position ids: reset the flow-matching states and restore from the prefill backup
self.semantic_fm.clear_all_states()
self.semantic_fm.load_state_dict(self.state_dict_backup)
return ret_wav
def clear_states(self):
self.semantic_fm.clear_all_states()
self.previous_chunk_left = None
self.pre_mel = None
self.pre_wav = None
def get_audio_detokenizer(model_path):
fm_model_config = os.path.join(model_path, "audio_detokenizer", "config.yaml")
fm_ckpt_path = os.path.join(model_path, "audio_detokenizer", "model.pt")
bigvgan_config_file = os.path.join(model_path, "vocoder", "config.json")
bigvgan_ckpt_path = os.path.join(model_path, "vocoder", "model.pt")
device = torch.cuda.current_device()
detokenizer = PrefixStreamingFlowMatchingDetokenizer.from_pretrained(
vocoder_config=bigvgan_config_file,
vocoder_ckpt=bigvgan_ckpt_path,
max_prompt_chunk=10, # 10 * 3 = 30s
fm_config=fm_model_config,
fm_ckpt=fm_ckpt_path,
device=device,
use_cfg=False,
look_ahead_tokens=12,
)
return detokenizer
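# Helper drivers for the streaming detokenizer. The variants taking ref_wav/ref_tokens
# first prefill the flow-matching model with a reference utterance and its semantic
# tokens so the generated speech keeps that reference timbre; the *_noref variants skip
# this prompt. The *_streaming variants yield audio chunk by chunk instead of
# concatenating everything at the end.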
def detokenize(detokenizer, tokens, ref_wav, ref_tokens):
with torch.no_grad():
detokenizer.clear_states()
detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
cache_speech_collection = []
chunk_size = 150
first_chunk_size = 100
first_chunk_tokens = tokens[:, :first_chunk_size]
gen_speech = detokenizer.detokenize_streaming(
first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size
)
cache_speech_collection.append(gen_speech)
res_tokens = tokens[:, first_chunk_size:]
for i in range(0, res_tokens.size(1), chunk_size):
chunk_tokens = res_tokens[:, i : i + chunk_size]
gen_speech = detokenizer.detokenize_streaming(
chunk_tokens, is_final=(i + chunk_size >= res_tokens.size(1))
)
cache_speech_collection.append(gen_speech)
gen_speech_all = torch.cat(cache_speech_collection, dim=-1)
return gen_speech_all
def detokenize_streaming(detokenizer, tokens, ref_wav, ref_tokens):
with torch.no_grad():
detokenizer.clear_states()
detokenizer.prefill(ref_wav, ref_tokens, chunk_size=150)
cache_speech_collection = []
chunk_size = 150
first_chunk_size = 100
first_chunk_tokens = tokens[:, :first_chunk_size]
gen_speech = detokenizer.detokenize_streaming(
first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size
)
yield gen_speech
res_tokens = tokens[:, first_chunk_size:]
for i in range(0, res_tokens.size(1), chunk_size):
chunk_tokens = res_tokens[:, i : i + chunk_size]
gen_speech = detokenizer.detokenize_streaming(
chunk_tokens, is_final=(i + chunk_size >= res_tokens.size(1))
)
yield gen_speech
def detokenize_noref(detokenizer, tokens):
with torch.no_grad():
detokenizer.clear_states()
cache_speech_collection = []
chunk_size = 150
first_chunk_size = 100
first_chunk_tokens = tokens[:, :first_chunk_size]
gen_speech = detokenizer.detokenize_streaming(
first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size
)
cache_speech_collection.append(gen_speech)
res_tokens = tokens[:, first_chunk_size:]
for i in range(0, res_tokens.size(1), chunk_size):
chunk_tokens = res_tokens[:, i : i + chunk_size]
gen_speech = detokenizer.detokenize_streaming(
chunk_tokens, is_final=(i + chunk_size >= res_tokens.size(1))
)
cache_speech_collection.append(gen_speech)
gen_speech_all = torch.cat(cache_speech_collection, dim=-1)
return gen_speech_all
def detokenize_noref_streaming(detokenizer, tokens):
with torch.no_grad():
detokenizer.clear_states()
cache_speech_collection = []
chunk_size = 150
first_chunk_size = 100
first_chunk_tokens = tokens[:, :first_chunk_size]
gen_speech = detokenizer.detokenize_streaming(
first_chunk_tokens, is_final=tokens.size(1) <= first_chunk_size
)
yield gen_speech
res_tokens = tokens[:, first_chunk_size:]
for i in range(0, res_tokens.size(1), chunk_size):
chunk_tokens = res_tokens[:, i : i + chunk_size]
gen_speech = detokenizer.detokenize_streaming(
chunk_tokens, is_final=(i + chunk_size >= res_tokens.size(1))
)
yield gen_speech
import os
import json
import logging
import librosa
import torch
from .vocoder.bigvgan import BigVGAN
from .vocoder.utils import get_melspec, AttrDict, load_checkpoint
logger = logging.getLogger(__name__)
class BigVGANWrapper:
def __init__(
self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None
) -> None:
self.vocoder = vocoder.to(device)
if dtype is not None:
self.vocoder = self.vocoder.to(dtype)
self.vocoder = self.vocoder.eval()
self.device = device
self.h = h
def to_dtype(self, dtype):
self.vocoder = self.vocoder.to(dtype)
def extract_mel_from_wav(self, wav_path=None, wav_data=None):
"""
params:
wav_path: str, path of the wav, should be 24k
wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k
return:
mel: [T, num_mels], torch.tensor
"""
if wav_data is None:
wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"])
wav_data = torch.tensor(wav_data).unsqueeze(0)
mel = get_melspec(
y=wav_data,
n_fft=self.h["n_fft"],
num_mels=self.h["num_mels"],
sampling_rate=self.h["sampling_rate"],
hop_size=self.h["hop_size"],
win_size=self.h["win_size"],
fmin=self.h["fmin"],
fmax=self.h["fmax"],
)
return mel.squeeze(0).transpose(0, 1)
@torch.inference_mode()
def extract_mel_from_wav_batch(self, wav_data):
"""
params:
wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k
return:
mel: [Batch, T, num_mels], torch.tensor
"""
wav_data = torch.tensor(wav_data)
mel = get_melspec(
wav=wav_data,
n_fft=self.h["n_fft"],
num_mels=self.h["num_mels"],
sampling_rate=self.h["sampling_rate"],
hop_size=self.h["hop_size"],
win_size=self.h["win_size"],
fmin=self.h["fmin"],
fmax=self.h["fmax"],
)
return mel.transpose(1, 2)
def decode_mel(self, mel):
"""
params:
mel: [T, num_mels], torch.tensor
return:
wav: [1, T], torch.tensor
"""
mel = mel.transpose(0, 1).unsqueeze(0).to(self.device)
wav = self.vocoder(mel)
return wav.squeeze(0)
def decode_mel_batch(self, mel):
"""
params:
mel: [B, T, num_mels], torch.tensor
return:
wav: [B, 1, T], torch.tensor
"""
mel = mel.transpose(1, 2).to(self.device)
wav = self.vocoder(mel)
return wav
@classmethod
def from_pretrained(cls, model_config, ckpt_path, device):
with open(model_config) as f:
data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
vocoder = BigVGAN(h, True)
state_dict_g = load_checkpoint(ckpt_path, "cpu")
vocoder.load_state_dict(state_dict_g["generator"])
logger.info(">>> Load vocoder from {}".format(ckpt_path))
return cls(vocoder, device, h)
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# x shape: bsz, seqlen, self.n_local_heads, self.head_hidden_dim / 2
# the last shape is "self.hidden_dim / 2" because we convert to complex
assert x.ndim == 4
assert freqs_cis.shape == (
x.shape[0],
x.shape[1],
x.shape[-1],
), f"x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}"
# reshape freq cis to match and apply pointwise multiply
# new shape: bsz, seq_len, 1, self.head_hidden_dim / 2
shape = [x.shape[0], x.shape[1], 1, x.shape[-1]]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
):
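# View adjacent feature pairs (x[..., 2i], x[..., 2i+1]) as complex numbers and
# multiply by the unit-magnitude entries of freqs_cis, i.e. rotate each pair by a
# position-dependent angle (rotary position embedding).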
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
flash_attention: bool = True,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = flash_attention
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.qk_norm = qk_norm
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x: torch.Tensor,
seq_len,
cu_seqlens,
max_seqlen,
cu_seqlens_k,
max_seqlen_k,
rotary_pos_emb=None,
incremental_state=None,
nopadding=True,
) -> torch.Tensor:
B, N, C = x.shape
if self.fused_attn:
if nopadding:
qkv = self.qkv(x)
qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim)
q, k, v = qkv.split([self.num_heads] * 3, dim=1)
q, k = self.q_norm(q), self.k_norm(k)
q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)
if rotary_pos_emb is not None:
q, k = apply_rotary_emb(q, k, rotary_pos_emb)
if incremental_state is not None:
if "prev_k" in incremental_state:
prev_k = incremental_state["prev_k"]
k = torch.cat([prev_k, k], dim=1)
if "cur_k" not in incremental_state:
incremental_state["cur_k"] = {}
incremental_state["cur_k"] = k
if "prev_v" in incremental_state:
prev_v = incremental_state["prev_v"]
v = torch.cat([prev_v, v], dim=1)
if "cur_v" not in incremental_state:
incremental_state["cur_v"] = {}
incremental_state["cur_v"] = v
q = q.view(B * N, self.num_heads, self.head_dim)
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
x = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen_k,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
else:
if incremental_state is not None:
raise NotImplementedError(
"It is designed for batching inference. AR-chunk is not supported currently."
)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
if self.qk_norm:
q, k, v = qkv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
# re-bind
qkv = torch.stack((q, k, v), dim=2)
# pack qkv with seq_len
qkv_collect = []
for i in range(qkv.shape[0]):
qkv_collect.append(qkv[i, : seq_len[i], :, :, :])
qkv = torch.cat(qkv_collect, dim=0)
x = flash_attn_varlen_qkvpacked_func(
qkv=qkv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
# unpack and pad 0
x_collect = []
for i in range(B):
x_collect.append(x[cu_seqlens[i] : cu_seqlens[i + 1], :, :])
x = torch.nn.utils.rnn.pad_sequence(
x_collect, batch_first=True, padding_value=0
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2)
x = x.reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=2)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(
self,
hidden_size,
num_heads,
mlp_ratio=4.0,
ffn_type="conv1d_conv1d",
ffn_gated_glu=True,
ffn_act_layer="gelu",
ffn_conv_kernel_size=5,
**block_kwargs,
):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
if ffn_type == "vanilla_mlp":
from timm.models.vision_transformer import Mlp
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
else:
raise NotImplementedError(f"FFN type {ffn_type} is not implemented")
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(
self,
x,
c,
seq_len,
cu_seqlens,
cu_maxlen,
cu_seqlens_k,
cu_maxlen_k,
mask,
rotary_pos_emb=None,
incremental_state=None,
nopadding=True,
):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.adaLN_modulation(c).chunk(6, dim=2)
)
x_ = modulate(self.norm1(x), shift_msa, scale_msa)
if incremental_state is not None:
if "attn_kvcache" not in incremental_state:
incremental_state["attn_kvcache"] = {}
inc_attn = incremental_state["attn_kvcache"]
else:
inc_attn = None
x_ = self.attn(
x_,
seq_len=seq_len,
cu_seqlens=cu_seqlens,
max_seqlen=cu_maxlen,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_k=cu_maxlen_k,
rotary_pos_emb=rotary_pos_emb,
incremental_state=inc_attn,
nopadding=nopadding,
)
if not nopadding:
x_ = x_ * mask[:, :, None]
x = x + gate_msa * x_
x_ = modulate(self.norm2(x), shift_mlp, scale_mlp)
x_ = self.mlp(x_)
if not nopadding:
x_ = x_ * mask[:, :, None]
x = x + gate_mlp * x_
return x
import torch
import torch.nn as nn
import math
from .dit_block import DiTBlock, FinalLayer
def precompute_freqs_cis(
dim: int,
end: int,
theta: float = 10000.0,
interpolation_factor: int = 1,
max_seq_length: int = 4096,
):
print(
f"using rope base theta = {theta}, interpolation factor = {interpolation_factor}"
)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# RoPE type-A extension
# we choose to use interpolation rather than extrapolation for better position encoding
# for scale purposes, t should be a float tensor
t = torch.arange(end, device=freqs.device).float()
scale = 1.0 / float(interpolation_factor)
t *= scale
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
# Sometimes we don't need this many RoPE embeddings because seq_len is smaller than max_pos_emb
# (e.g. RoPE precomputed for 1M positions but a 32k sequence), which would waste GPU memory
if max_seq_length < end:
        freqs_cis = freqs_cis[:max_seq_length].clone()
return freqs_cis
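# --- Illustrative sketch (not part of the original source): one common way the
# complex frequencies returned by `precompute_freqs_cis` are applied, by viewing
# the last head dimension as complex pairs and rotating them. The application
# step shown here is an assumption for illustration; the attention module's own
# rotary code (not shown in this file) is authoritative.
import torch
_head_dim, _seq = 8, 16
_freqs_cis = precompute_freqs_cis(_head_dim, _seq)       # (seq, head_dim // 2), complex64
_q = torch.randn(1, _seq, 1, _head_dim)                  # (B, T, num_heads, head_dim)
_q_c = torch.view_as_complex(_q.float().reshape(1, _seq, 1, -1, 2))
_q_rot = torch.view_as_real(_q_c * _freqs_cis[None, :, None, :]).flatten(-2)
assert _q_rot.shape == _q.shape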
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = (
torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
)
.float()
.to(device=t.device)
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb
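# --- Illustrative sketch (not part of the original source): the sinusoidal
# timestep embedding pairs cos/sin features over geometrically spaced
# frequencies, as in the GLIDE reference linked above. Toy values.
import torch
_t = torch.tensor([0.0, 250.0, 999.0])                    # (N,) diffusion timesteps
_emb = TimestepEmbedder.timestep_embedding(_t, dim=256)   # (N, 256)
assert _emb.shape == (3, 256)
# t = 0 gives cos(0) = 1 for the first half and sin(0) = 0 for the second half.
assert torch.allclose(_emb[0, :128], torch.ones(128))
assert torch.allclose(_emb[0, 128:], torch.zeros(128))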
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
)
self.register_buffer("_float_tensor", torch.FloatTensor(1))
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2 # d/2
emb = math.log(10000) / (half_dim - 1) # 2*log(10000)/(d-2)
emb = torch.exp(
torch.arange(half_dim, dtype=torch.float) * -emb
) # -2i/(d-2)*log(10000); i from 0 to (d-2)/2; shape: (d/2, )
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
1
) * emb.unsqueeze(
0
) # pos/[1000 ** (2i/(d-2))]; shape: (num_embeddings, d/2)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
num_embeddings, -1
) # shape: (num_embeddings, d)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, incremental_state=None, timestep=None, **kwargs):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input.shape[:2]
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
)
self.weights = self.weights.to(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = self.make_positions(input, self.padding_idx)
return (
self.weights.index_select(0, positions.view(-1))
.view(bsz, seq_len, -1)
.detach()
) # (B, T, dim)
def max_positions(self):
"""Maximum number of supported positions."""
        return int(1e5)  # an arbitrarily large number
def make_positions(self, tensor, padding_idx):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
mask = tensor.ne(padding_idx).int()
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
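# --- Illustrative sketch (not part of the original source): `make_positions`
# assigns positions starting at padding_idx + 1 and keeps padded slots at
# padding_idx, so padding never receives a real position. Toy batch below.
import torch
_pad = 0
_tokens = torch.tensor([[7, 7, 7, _pad, _pad],
                        [7, 7, 7, 7, 7]])
_pos_emb = SinusoidalPositionalEmbedding(embedding_dim=8, padding_idx=_pad)
_positions = _pos_emb.make_positions(_tokens, _pad)
# -> [[1, 2, 3, 0, 0],
#     [1, 2, 3, 4, 5]]
assert _positions.tolist() == [[1, 2, 3, 0, 0], [1, 2, 3, 4, 5]]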
class DiTPrefix(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size,
output_size,
semantic_vocab_size,
hidden_size=1024,
depth=12,
num_heads=4,
# mlp related
mlp_ratio=4.0,
ffn_type="conv1d_conv1d",
ffn_gated_glu=True,
ffn_act_layer="gelu",
ffn_conv_kernel_size=5,
# rope
use_rope=False,
rope_params={
"max_position_embeddings": 4096,
"rope_base": 10000.0,
"rope_interpolation_factor": 1.0,
},
position_embedding_type="sincos",
max_seq_len=4096,
prompt_cfg_dropout=0.0,
):
super().__init__()
self.num_heads = num_heads
self.prompt_cfg_dropout = prompt_cfg_dropout
self.t_embedder = TimestepEmbedder(hidden_size)
self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)
self.input_linear = nn.Linear(input_size, hidden_size)
# position embedding
if position_embedding_type == "learnable":
self.position_embedding = nn.Embedding(max_seq_len + 1, hidden_size)
elif position_embedding_type == "sincos":
self.position_embedding = SinusoidalPositionalEmbedding(
hidden_size, 0, max_seq_len + 1
)
elif position_embedding_type == "skip":
self.position_embedding = None
else:
raise NotImplementedError(
"Position embedding type: {} not implemented.".format(
position_embedding_type
)
)
self.use_rope = use_rope
if self.use_rope:
assert (
hidden_size % num_heads == 0
), "Hidden size must be divisible by num_heads for rope position embedding."
rope_dim = hidden_size // num_heads
self.rotary_pos_emb = precompute_freqs_cis(
rope_dim,
rope_params["max_position_embeddings"],
theta=rope_params["rope_base"],
interpolation_factor=rope_params["rope_interpolation_factor"],
max_seq_length=max_seq_len,
)
self.blocks = nn.ModuleList(
[
DiTBlock(
hidden_size,
num_heads,
mlp_ratio=mlp_ratio,
ffn_type=ffn_type,
ffn_conv_kernel_size=ffn_conv_kernel_size,
ffn_gated_glu=ffn_gated_glu,
ffn_act_layer=ffn_act_layer,
)
for _ in range(depth)
]
)
self.final_layer = FinalLayer(hidden_size, output_size)
self.initialize_weights()
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(
self,
x,
position_ids,
t,
condition,
seq_len,
cu_seqlens,
cu_maxlen,
cu_seqlens_k,
cu_maxlen_k,
mask,
incremental_state=None,
nopadding=True,
):
"""
Forward pass of DiT.
x: (N, T, C) tensor of inputs (latent representations of speech)
position_ids: (N, T) tensor of positional indices
t: (N,) tensor of diffusion timesteps
condition: (N, T) tensor of semantic tokens
seq_len: (N,) tensor of sequence lengths
"""
condition = self.semantic_token_embedding(condition) # (N, T, D)
x = self.input_linear(x)
if self.position_embedding is not None:
position_emb = self.position_embedding(position_ids)
x = x + position_emb
# ROPE
if self.use_rope:
bsz, seqlen = position_ids.shape
if self.rotary_pos_emb.device != position_ids.device:
self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
rotary_pos_emb = torch.zeros(
(bsz, seqlen, self.rotary_pos_emb.shape[1]),
dtype=self.rotary_pos_emb.dtype,
device=self.rotary_pos_emb.device,
)
for b in range(bsz):
cur_rope = rotary_pos_emb[b]
cur_position_ids = position_ids[b]
cur_rope[:] = self.rotary_pos_emb[cur_position_ids]
else:
rotary_pos_emb = None
t = self.t_embedder(t) # (N, D)
c = t.unsqueeze(1) + condition # (N, T, D)
for block_idx, block in enumerate(self.blocks):
# x = block(x, c, attn_mask) # (N, T, D)
# XXX mask could be None because we always use full mask
if incremental_state is not None:
if block_idx not in incremental_state:
incremental_state[block_idx] = {}
incr = incremental_state[block_idx]
else:
incr = None
x = block(
x=x,
c=c,
seq_len=seq_len,
cu_seqlens=cu_seqlens,
cu_maxlen=cu_maxlen,
cu_seqlens_k=cu_seqlens_k,
cu_maxlen_k=cu_maxlen_k,
mask=mask,
rotary_pos_emb=rotary_pos_emb,
incremental_state=incr,
nopadding=nopadding,
)
x = self.final_layer(x, c) # (N, T, C)
return x
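# --- Illustrative sketch (not part of the original source): how DiTPrefix.forward
# combines the per-sample timestep embedding with the per-token semantic condition
# to build the adaLN conditioning tensor handed to every block. Toy shapes.
import torch
_t_emb = torch.randn(2, 8)           # (N, D), output of TimestepEmbedder
_cond = torch.randn(2, 5, 8)         # (N, T, D), output of semantic_token_embedding
_c = _t_emb.unsqueeze(1) + _cond     # (N, T, D): the same timestep is added to every token
assert _c.shape == (2, 5, 8)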
import torch
import torch.nn as nn
from functools import lru_cache
import copy
@lru_cache(maxsize=1)
def get_cached_zeros(numel, device="cpu", dtype=torch.float32):
return torch.zeros(numel, device=device, dtype=dtype)
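# --- Illustrative sketch (not part of the original source): `get_cached_zeros`
# memoizes a single zeros tensor (keyed on numel/device/dtype via lru_cache) so
# the per-step timestep broadcast in the ODE wrapper's forward avoids
# reallocating. The returned tensor is shared, so callers should not mutate it
# in place; adding to it produces a fresh tensor, as below. Toy values.
import torch
_base = get_cached_zeros(4, device="cpu", dtype=torch.long)
assert get_cached_zeros(4, device="cpu", dtype=torch.long) is _base   # cache hit
_t_vec = _base + 300                  # e.g. t = 0.3 scaled to the 1000-step grid
assert _t_vec.tolist() == [300, 300, 300, 300]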
class StreamingODEWrapperForPrefix(nn.Module):
def __init__(
self,
net,
x_mask,
x_cond,
use_cfg=False,
use_cfg_rescale=True,
cfg_init=1.0,
cfg_scale=4.0,
cfg_schedule="linear",
cfg_token_id=0,
):
super(StreamingODEWrapperForPrefix, self).__init__()
self.net = net
self.x_mask = x_mask
self.x_cond = x_cond
        assert not use_cfg, "cfg is not supported in streaming detokenizer"
self.use_cfg = use_cfg
self.use_cfg_rescale = use_cfg_rescale
self.cfg_init = cfg_init
self.cfg_scale = cfg_scale
self.cfg_token_id = cfg_token_id
self.cfg_schedule = cfg_schedule
self.position_ids = None
self.seq_len = None
self.incremental_state = {}
self.kv_cache_tokens = 0
self.cu_seqlens = None
self.cu_maxlen = None
self.cu_seqlens_k = None
self.cu_maxlen_k = None
self.previous_seqlen = None
def clear_all_states(self):
self.incremental_state = {}
self.kv_cache_tokens = 0
self.cu_seqlens = None
self.cu_maxlen = None
self.cu_seqlens_k = None
self.cu_maxlen_k = None
self.previous_seqlen = None
def state_dict(self):
return {
"incremental_state": copy.deepcopy(self.incremental_state),
"kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens),
"cu_seqlens": copy.deepcopy(self.cu_seqlens),
"cu_maxlen": copy.deepcopy(self.cu_maxlen),
"cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k),
"cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k),
"previous_seqlen": copy.deepcopy(self.previous_seqlen),
}
def load_state_dict(self, state_dict):
self.incremental_state = state_dict["incremental_state"]
self.kv_cache_tokens = state_dict["kv_cache_tokens"]
self.cu_seqlens = state_dict["cu_seqlens"]
self.cu_maxlen = state_dict["cu_maxlen"]
self.cu_seqlens_k = state_dict["cu_seqlens_k"]
self.cu_maxlen_k = state_dict["cu_maxlen_k"]
self.previous_seqlen = state_dict["previous_seqlen"]
def set_conditions(self, x_mask, x_cond, start_position_id, cache={}):
if not self.use_cfg:
self.x_mask = x_mask
self.x_cond = x_cond
else:
self.x_cond = torch.cat((x_cond, x_cond), dim=0)
self.x_mask = torch.cat((x_mask, x_mask), dim=0)
position_ids_cur = [
i
for i in range(start_position_id, self.x_cond.shape[1] + start_position_id)
]
position_ids = torch.tensor([position_ids_cur])
if not self.use_cfg:
self.position_ids = position_ids.to(self.x_cond.device).long()
self.seq_len = (
torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long()
)
else:
self.position_ids = (
torch.cat((position_ids, position_ids), dim=0)
.to(self.x_cond.device)
.long()
)
self.seq_len = (
torch.Tensor([position_ids.shape[1], position_ids.shape[1]])
.to(self.x_cond.device)
.long()
)
cu_seqlens = torch.cumsum(self.seq_len, dim=0)
self.cu_seqlens = torch.cat(
[torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0
).int()
self.cu_maxlen = self.seq_len.cpu().max()
if self.cu_seqlens_k is None:
self.cu_seqlens_k = self.cu_seqlens
self.cu_maxlen_k = self.cu_maxlen
previous_seqlen = self.seq_len
else:
previous_seqlen_old = cache["previous_seqlen"]
previous_seqlen = previous_seqlen_old + self.seq_len
# calculate cu_seqlens_k
cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0)
self.cu_seqlens_k = torch.cat(
[torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0
).int()
self.cu_maxlen_k = previous_seqlen.cpu().max()
self.previous_seqlen = previous_seqlen
ret_cache = {"previous_seqlen": previous_seqlen}
return ret_cache
    def update_incremental_state(
        self,
        reserve_kv_cache_tokens=0,
        max_kv_cache_tokens=900,
        condition_cache=None,
    ):
        # condition_cache is written below; default to a fresh dict instead of the
        # original set literal, which would not support item assignment.
        if condition_cache is None:
            condition_cache = {}
assert (
reserve_kv_cache_tokens <= max_kv_cache_tokens
), "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens"
for layer_idx, layer_cache in self.incremental_state.items():
# update attention kv cache
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"]
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"]
self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
if self.kv_cache_tokens > max_kv_cache_tokens:
                # drop old tokens: keep the first reserve_kv_cache_tokens and the most
                # recent (max_kv_cache_tokens - reserve_kv_cache_tokens) tokens
reserve_tokens_excludeprompt = (
max_kv_cache_tokens - reserve_kv_cache_tokens
)
if reserve_kv_cache_tokens == 0:
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"][
"prev_k"
][:, -reserve_tokens_excludeprompt:]
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"][
"prev_v"
][:, -reserve_tokens_excludeprompt:]
elif reserve_tokens_excludeprompt == 0:
layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"][
"prev_k"
][:, :reserve_kv_cache_tokens]
layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"][
"prev_v"
][:, :reserve_kv_cache_tokens]
else:
layer_cache["attn_kvcache"]["prev_k"] = torch.cat(
[
layer_cache["attn_kvcache"]["prev_k"][
:, :reserve_kv_cache_tokens
],
layer_cache["attn_kvcache"]["prev_k"][
:, -reserve_tokens_excludeprompt:
],
],
dim=1,
)
layer_cache["attn_kvcache"]["prev_v"] = torch.cat(
[
layer_cache["attn_kvcache"]["prev_v"][
:, :reserve_kv_cache_tokens
],
layer_cache["attn_kvcache"]["prev_v"][
:, -reserve_tokens_excludeprompt:
],
],
dim=1,
)
bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0]
self.previous_seqlen = (
torch.Tensor(
[
layer_cache["attn_kvcache"]["prev_k"].shape[1]
for i in range(bsz)
]
)
.to(layer_cache["attn_kvcache"]["prev_k"].device)
.long()
)
condition_cache["previous_seqlen"] = self.previous_seqlen
self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1]
# clear current cache
layer_cache["attn_kvcache"].pop("cur_k")
layer_cache["attn_kvcache"].pop("cur_v")
def forward(self, t, x, args=None):
# t = torch.tensor([t * 1000] * x.shape[0], device=x.device, dtype=x.dtype).long()
t = (
get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long)
+ (t * 1000).long()
)
if self.use_cfg:
raise NotImplementedError("cfg is not supported in streaming detokenizer.")
else:
pred_noise = self.net(
x=x,
condition=self.x_cond,
t=t,
position_ids=self.position_ids,
cu_seqlens=self.cu_seqlens,
cu_maxlen=self.cu_maxlen,
cu_seqlens_k=self.cu_seqlens_k,
cu_maxlen_k=self.cu_maxlen_k,
incremental_state=self.incremental_state,
nopadding=True,
mask=None,
seq_len=None,
)
return pred_noise
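# --- Illustrative sketch (not part of the original source): the cumulative
# sequence-length bookkeeping built in `set_conditions` for varlen attention.
# cu_seqlens delimits the query tokens of the current chunk; cu_seqlens_k also
# counts tokens already held in the KV cache. Toy values, single stream (bsz=1).
import torch
_seq_len = torch.tensor([6])                   # tokens in the current chunk
_cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(_seq_len, dim=0)]).int()
_prev = torch.tensor([10])                     # tokens already in the KV cache
_cu_seqlens_k = torch.cat([torch.tensor([0]), torch.cumsum(_prev + _seq_len, dim=0)]).int()
assert _cu_seqlens.tolist() == [0, 6]          # queries: one sequence of 6 tokens
assert _cu_seqlens_k.tolist() == [0, 16]       # keys: 10 cached + 6 new tokens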
import torch
from abc import abstractmethod, ABC
try:
from torchdyn.core import NeuralODE
NEURALODE_INSTALLED = True
except ImportError:
NEURALODE_INSTALLED = False
class SchedulerBase(ABC):
def __init__(self) -> None:
pass
@abstractmethod
def set_timesteps(self):
pass
@abstractmethod
def step(self):
pass
@abstractmethod
def add_noise(self):
pass
class StreamingFlowMatchingScheduler(SchedulerBase):
def __init__(
self,
timesteps=1000,
sigma_min=1e-4,
) -> None:
super().__init__()
self.sigma_min = sigma_min
self.timesteps = timesteps
self.t_min = 0
self.t_max = 1 - self.sigma_min
self.neural_ode = None
def set_timesteps(self, timesteps=15):
self.timesteps = timesteps
def step(self, xt, predicted_v):
h = (self.t_max - self.t_min) / self.timesteps
h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
xt = xt + h * predicted_v
return xt
def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
h = (self.t_max - self.t_min) / self.timesteps
h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device)
if verbose:
gt_v = x0 - xt
for t in time_steps:
predicted_v = ode_wrapper(t, xt)
if verbose:
dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v))
print("Time: {}, Distance: {}".format(t, dist))
xt = xt + h * predicted_v
return xt
def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None):
if not NEURALODE_INSTALLED:
            raise ImportError("torchdyn (NeuralODE) is not installed; please install it first.")
if self.neural_ode is None:
self.neural_ode = NeuralODE(
ode_wrapper,
solver="euler",
sensitivity="adjoint",
atol=self.sigma_min,
rtol=self.sigma_min,
)
eval_points, traj = self.neural_ode(xt, time_steps)
return traj[-1]
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
):
        ut = original_samples - (1 - self.sigma_min) * noise  # unrelated to the gradient of ut
t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps
x_noisy = (
t_unsqueeze * original_samples
+ (1.0 - (1 - self.sigma_min) * t_unsqueeze) * noise
)
return x_noisy, ut
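# --- Illustrative sketch (not part of the original source): the conditional
# flow-matching path used by `add_noise` and the Euler update used by `step`:
# x_t = t * x_0 + (1 - (1 - sigma_min) * t) * noise, with target velocity
# u_t = x_0 - (1 - sigma_min) * noise; one Euler step advances x_t by h * u_t.
# Toy values below.
import torch
_sched = StreamingFlowMatchingScheduler(timesteps=1000)
_x0 = torch.randn(1, 4, 8)                 # clean latent
_noise = torch.randn_like(_x0)
_t = torch.tensor([500])                   # mid-trajectory timestep
_xt, _ut = _sched.add_noise(_x0, _noise, _t)
assert _xt.shape == _x0.shape and _ut.shape == _x0.shape
_sched.set_timesteps(15)                   # e.g. 15 Euler steps at inference
_x_next = _sched.step(_xt, _ut)            # x_{t+h} = x_t + h * v
assert _x_next.shape == _xt.shape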