Commit 85f0282a authored by Lengyue, committed by zjsun

init training code

parent 66ea8ff4
.pgx.*
.pdm-python
/speech_lm.egg-info
__pycache__
/results
import json
import os
from pathlib import Path

import librosa
import torch
from datasets import Dataset
from multiprocess import set_start_method
from transformers import AutoProcessor, EncodecModel

# Workers must use "spawn" so each process can safely initialize CUDA.
set_start_method("spawn", force=True)

# Shared EnCodec model/processor used by every dataset worker.
encodec_name = "facebook/encodec_24khz"
encodec_processor = AutoProcessor.from_pretrained(encodec_name)
encodec_model = EncodecModel.from_pretrained(encodec_name)
encodec_model.eval()
def tokenize(text, audio, sr=None, speaker=None):
    assert sr is None or sr == encodec_processor.sampling_rate

    # Load audio from disk if a path was given.
    if isinstance(audio, (str, Path)):
        audio, sr = librosa.load(audio, sr=sr, mono=True)

    # Build the instruction-style text prompt.
    prompt = "[INST] "
    if speaker:
        prompt += f"[SPK] {speaker} [/SPK] "
    prompt += f"{text} [/INST] "

    # Encode the waveform into discrete EnCodec codes.
    inputs = encodec_processor(
        raw_audio=audio, sampling_rate=sr, return_tensors="pt"
    ).to(encodec_model.device)
    outputs = encodec_model.encode(
        inputs["input_values"], inputs["padding_mask"], bandwidth=1.5, return_dict=True
    )

    assert outputs.audio_codes.dim() == 4  # [batch, channel, codebook, code]
    assert outputs.audio_codes.shape[0] == outputs.audio_codes.shape[1] == 1

    # Keep only the first codebook and render it as <encodec_N> text tokens.
    codes = outputs.audio_codes[0, 0, 0, :].long()
    codes_str = " ".join([f"<encodec_{int(c)}>" for c in codes.tolist()])
    prompt += codes_str

    return {
        "prompt": prompt,
        "codes": codes,
    }
def wrap_tokenize(x):
    # Each dataset worker moves the shared EnCodec model onto GPU 0 once.
    device = torch.device("cuda", 0)
    if encodec_model.device != device:
        encodec_model.to(device)

    return tokenize(
        text=x["text"],
        audio=x["raw_path"],
        sr=encodec_processor.sampling_rate,
        speaker=x["speaker"],
    )
def generator_libritts_r():
    base = Path("dataset/tts/LibriTTS_R")

    for i in base.rglob("*.wav"):
        # Skip utterances without a normalized transcript.
        text_file = i.with_suffix(".normalized.txt")
        if not text_file.exists():
            continue

        text = text_file.read_text().strip()
        yield {
            "text": text,
            "speaker": f"libritts_{i.parent.parent.name}",
            "raw_path": str(i),
            "path": str(i.relative_to(base)),
        }
if __name__ == "__main__":
    dataset = Dataset.from_generator(generator_libritts_r)
    dataset = dataset.map(wrap_tokenize, num_proc=12)
    dataset = dataset.remove_columns(["raw_path"])
    dataset.save_to_disk("dataset/tts/libritts-r-encodec")
    dataset.push_to_hub("fishaudio/libritts-r-encodec", private=True)
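Note: as a quick sanity check (not part of this commit), the dataset written by the script above can be reloaded and inspected; the snippet below assumes save_to_disk has already completed.

from datasets import load_from_disk

# Reload the tokenized LibriTTS-R dataset produced by the script above.
dataset = load_from_disk("dataset/tts/libritts-r-encodec")

sample = dataset[0]
print(sample["speaker"], sample["path"])
print(sample["prompt"][:120])                 # "[INST] [SPK] ... [/SPK] <text> [/INST] <encodec_...> ..."
print(len(sample["codes"]), "EnCodec codes")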
@@ -12,10 +12,12 @@ dependencies = [
"torchaudio>=2.1.0",
"transformers>=4.34.0",
"datasets>=2.14.5",
"accelerate>=0.23.0",
"bitsandbytes>=0.41.1",
"peft>=0.5.0",
"omegaconf>=2.3.0",
"deepspeed>=0.11.1",
"lightning>=2.0.9.post0",
"hydra-core>=1.3.2",
"pyrootutils>=1.0.4",
]
requires-python = ">=3.10"
license = {text = "MIT"}
paths:
  run_dir: results/pretrain

hydra:
  run:
    dir: ${paths.run_dir}
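For context (not part of this commit), the ${paths.run_dir} value above is an OmegaConf interpolation: it resolves against paths.run_dir when the composed config is read. A minimal standalone sketch of the same mechanism:

from omegaconf import OmegaConf

# Mirror of the config fragment above; the interpolation resolves on access.
cfg = OmegaConf.create(
    {
        "paths": {"run_dir": "results/pretrain"},
        "hydra": {"run": {"dir": "${paths.run_dir}"}},
    }
)
assert cfg.hydra.run.dir == "results/pretrain"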
import torch
from lightning.fabric import Fabric
import hydra
from omegaconf import DictConfig, OmegaConf
import pyrootutils

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver and root
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
OmegaConf.register_new_resolver("eval", eval)

# flake8: noqa: E402
from speech_lm.dataset import build_dataset
@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    print(cfg)


if __name__ == "__main__":
    main()
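The "eval" resolver registered in this entry point lets YAML configs embed small Python expressions via ${eval:...}; a minimal sketch of that behaviour (the total_steps key below is illustrative, not from this commit):

from omegaconf import OmegaConf

# Same resolver as in the training script; replace=True avoids re-registration errors.
OmegaConf.register_new_resolver("eval", eval, replace=True)

cfg = OmegaConf.create({"batch_size": 8, "total_steps": '${eval:"1000 * 8"}'})
assert cfg.total_steps == 8000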