Commit 3344c738 authored by zjsun

Merge branch 'my-feature-branch' into 'main'

My feature branch

See merge request !1
parents d0ceb4e4 85f0282a
.pgx.*
.pdm-python
/speech_lm.egg-info
__pycache__
/results
ci:
  autoupdate_schedule: monthly
repos:
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 23.7.0
    hooks:
      - id: black
  - repo: https://github.com/codespell-project/codespell
    rev: v2.2.5
    hooks:
      - id: codespell
        files: ^.*\.(py|md|rst|yml)$
{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": 1.0,
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "flops_profiler": {
    "enabled": false,
    "profile_step": 1,
    "module_depth": -1,
    "top_modules": 1,
    "detailed": true,
    "output_file": null
  }
}
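For context, a minimal sketch of how a DeepSpeed JSON like the one above can be handed to Lightning Fabric; the config path, accelerator, and precision flag here are assumptions for illustration, not part of this merge request:

from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# Hypothetical wiring: point the strategy at the JSON config above
# (the "configs/ds_config.json" path is assumed, not taken from this commit).
strategy = DeepSpeedStrategy(config="configs/ds_config.json")
fabric = Fabric(accelerator="cuda", precision="bf16-mixed", strategy=strategy)
fabric.launch()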
import json
import os
from pathlib import Path

import librosa
import torch
from datasets import Dataset
from multiprocess import set_start_method
from transformers import AutoProcessor, EncodecModel

set_start_method("spawn", force=True)

encodec_name = "facebook/encodec_24khz"
encodec_processor = AutoProcessor.from_pretrained(encodec_name)
encodec_model = EncodecModel.from_pretrained(encodec_name)
encodec_model.eval()


def tokenize(text, audio, sr=None, speaker=None):
    assert sr is None or sr == encodec_processor.sampling_rate

    if isinstance(audio, (str, Path)):
        audio, sr = librosa.load(audio, sr=sr, mono=True)

    prompt = "[INST] "
    if speaker:
        prompt += f"[SPK] {speaker} [/SPK] "
    prompt += f"{text} [/INST] "

    inputs = encodec_processor(
        raw_audio=audio, sampling_rate=sr, return_tensors="pt"
    ).to(encodec_model.device)
    outputs = encodec_model.encode(
        inputs["input_values"], inputs["padding_mask"], bandwidth=1.5, return_dict=True
    )

    assert outputs.audio_codes.dim() == 4  # [chunks, batch, codebook, code]
    assert outputs.audio_codes.shape[0] == outputs.audio_codes.shape[1] == 1

    # Keep only the first codebook and render it as <encodec_*> tokens
    codes = outputs.audio_codes[0, 0, 0, :].long()
    codes_str = " ".join([f"<encodec_{int(c)}>" for c in codes.tolist()])
    prompt += codes_str

    return {
        "prompt": prompt,
        "codes": codes,
    }


def wrap_tokenize(x):
    device = torch.device("cuda", 0)
    if encodec_model.device != device:
        encodec_model.to(device)

    return tokenize(
        text=x["text"],
        audio=x["raw_path"],
        sr=encodec_processor.sampling_rate,
        speaker=x["speaker"],
    )


def generator_libritts_r():
    base = Path("dataset/tts/LibriTTS_R")

    for i in base.rglob("*.wav"):
        text_file = i.with_suffix(".normalized.txt")
        if not text_file.exists():
            continue

        text = text_file.read_text().strip()
        yield {
            "text": text,
            "speaker": f"libritts_{i.parent.parent.name}",
            "raw_path": str(i),
            "path": str(i.relative_to(base)),
        }


if __name__ == "__main__":
    dataset = Dataset.from_generator(generator_libritts_r)
    dataset = dataset.map(wrap_tokenize, num_proc=12)
    dataset = dataset.remove_columns(["raw_path"])
    dataset.save_to_disk("dataset/tts/libritts-r-encodec")
    dataset.push_to_hub("fishaudio/libritts-r-encodec", private=True)
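As a quick sanity check, the tokenized dataset can be reloaded and inspected; this is a hedged usage sketch, not code from this merge request:

from datasets import load_from_disk

# Reload the dataset written by the script above and peek at one prompt.
ds = load_from_disk("dataset/tts/libritts-r-encodec")
print(ds)
print(ds[0]["prompt"][:200])  # "[INST] ... [/INST] <encodec_...> ..."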
import random
import subprocess
from multiprocessing import Pool, cpu_count
from pathlib import Path

from tqdm import tqdm


def convert_to_flac(src_file_path):
    dst_file_path = src_file_path.with_suffix(".flac")
    dst_file_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        subprocess.check_call(
            [
                "ffmpeg",
                "-y",
                "-i",
                str(src_file_path),
                "-acodec",
                "flac",
                "-threads",
                "0",
                str(dst_file_path),
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

        # remove the input file
        src_file_path.unlink()
        return True
    except subprocess.CalledProcessError:
        return False


if __name__ == "__main__":
    src_dir = Path("dataset/tts/WenetSpeech/cleaned")
    wav_files = list(src_dir.rglob("*.wav"))
    random.shuffle(wav_files)
    print(f"Found {len(wav_files)} wav files")

    success_counter = 0
    fail_counter = 0

    with Pool(processes=cpu_count(), maxtasksperchild=100) as pool:
        with tqdm(
            pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)
        ) as pbar:
            for success in pbar:
                if success:
                    success_counter += 1
                else:
                    fail_counter += 1

                pbar.set_description(
                    f"Success: {success_counter}, Fail: {fail_counter}"
                )

    print(f"Successfully converted: {success_counter}")
    print(f"Failed conversions: {fail_counter}")
import json
import os
import subprocess
import tempfile
import time
from pathlib import Path

import librosa
import soundfile as sf
import torch
import torchaudio
from fish_audio_preprocess.utils.separate_audio import (
    init_model,
    merge_tracks,
    separate_audio,
)
from tqdm import tqdm

rank = int(os.environ.get("SLURM_PROCID", 0))
world_size = int(os.environ.get("SLURM_NTASKS", 1))
device = torch.device("cuda:0")

print(f"Rank {rank}/{world_size} on {device}")


def main():
    meta_path = Path("dataset/tts/WenetSpeech/WenetSpeech.json")
    dataset_path = Path("dataset/tts/WenetSpeech")
    cleaned_path = Path("dataset/tts/WenetSpeech/cleaned")

    if not cleaned_path.exists():
        cleaned_path.mkdir(parents=True)

    demucs = init_model("htdemucs", device)
    print("Model loaded")

    with open(meta_path) as f:
        dataset = json.load(f)["audios"]

    print(f"Dataset loaded, {len(dataset)} samples")

    # Each SLURM task processes an interleaved shard of the dataset
    dataset = dataset[rank::world_size]
    print(f"Dataset split, {len(dataset)} samples")

    for data_idx, data in enumerate(dataset):
        done_path = cleaned_path / data["aid"] / "done"
        done_path.parent.mkdir(parents=True, exist_ok=True)

        if done_path.exists():
            continue

        print(f"Processing {data_idx}/{len(dataset)} at rank {rank}")

        try:
            with tempfile.NamedTemporaryFile(suffix=".wav") as f:
                subprocess.check_call(
                    [
                        "ffmpeg",
                        "-y",
                        "-i",
                        str(dataset_path / data["path"]),
                        "-c:a",
                        "pcm_s16le",
                        "-threads",
                        "0",
                        "-ar",
                        "24000",
                        str(f.name),
                    ],
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )

                raw_audio, sr = librosa.load(f.name, sr=None, mono=True)

            raw_audio = torch.from_numpy(raw_audio[None]).to(device)
            audio = torchaudio.functional.resample(
                raw_audio, orig_freq=sr, new_freq=demucs.samplerate
            )

            # Make it 2 channels
            audio = torch.cat([audio, audio], dim=0)

            tracks = separate_audio(
                demucs, audio, shifts=1, num_workers=0, progress=False
            )
            audio = merge_tracks(tracks, filter=["vocals"])[0]
            vocals, sr = (
                torchaudio.functional.resample(
                    audio, orig_freq=demucs.samplerate, new_freq=24000
                ),
                24000,
            )
            vocals = vocals.cpu().numpy()

            for idx, segment in enumerate(data["segments"]):
                if segment["confidence"] <= 0.95:
                    continue

                # Load audio
                begin = int(segment["begin_time"] * sr)
                end = int(segment["end_time"] * sr)
                segment_audio = vocals[begin:end]

                # Write audio
                temp_path = cleaned_path / data["aid"] / f"S{idx:05d}.wav"
                temp_path.parent.mkdir(parents=True, exist_ok=True)
                sf.write(temp_path, segment_audio, samplerate=sr)

                # Write text
                temp_path = temp_path.with_suffix(".txt")
                temp_path.write_text(segment["text"])

            # Write done file
            done_path.write_text("")
        except Exception as e:
            print(f"Error {e} on {data_idx}/{len(dataset)} at rank {rank}")
            time.sleep(10)
            continue

    print("Done")


if __name__ == "__main__":
    main()
import tarfile
from pathlib import Path
from tqdm import tqdm
import io
import random
from multiprocessing import Process


def chunked_tarring(rank, file_list, base_folder, output_folder, chunk_size=1024**3):
    chunk_count = 1
    total_size = 0
    saved_count = 0

    buffer = io.BytesIO()
    tar = tarfile.open(fileobj=buffer, mode="w")

    for audio_file in file_list:
        txt_file = audio_file.with_suffix(".txt")
        if not txt_file.exists():
            continue

        file_size = audio_file.stat().st_size + txt_file.stat().st_size

        # Flush the current tar to disk once it would exceed the chunk size
        if total_size + file_size > chunk_size:
            tar.close()

            # write the buffer to disk
            buffer.seek(0)
            with open(
                output_folder / f"chunk-{rank:03d}-{chunk_count:04d}.tar", "wb"
            ) as f:
                f.write(buffer.read())

            chunk_count += 1
            total_size = 0

            buffer.close()
            buffer = io.BytesIO()
            tar = tarfile.open(fileobj=buffer, mode="w")

        tar.add(audio_file, arcname=audio_file.relative_to(base_folder))
        tar.add(txt_file, arcname=txt_file.relative_to(base_folder))
        total_size += file_size

        if saved_count % 1000 == 0:
            print(f"Rank {rank}: {saved_count}/{len(file_list)}")

        saved_count += 1

    # Write the last (possibly partial) chunk
    tar.close()
    buffer.seek(0)
    with open(
        output_folder / f"chunk-{rank:03d}-{chunk_count:04d}.tar", "wb"
    ) as f:
        f.write(buffer.read())

    print(f"Rank {rank}: {saved_count}/{len(file_list)}")


if __name__ == "__main__":
    base_folder = Path("/mnt/nvme1/multi-modal-test/WenetSpeech/cleaned")
    output_folder = Path("/mnt/nvme1/multi-modal-test/WenetSpeech/compressed")
    output_folder.mkdir(exist_ok=True, parents=True)

    num_workers = 50
    file_list = list(tqdm(base_folder.rglob("*.flac")))
    random.shuffle(file_list)
    print(f"Total files: {len(file_list)}")

    chunk_size = len(file_list) // num_workers
    processes = []

    for i in range(num_workers):
        start = i * chunk_size
        end = (i + 1) * chunk_size
        if i == num_workers - 1:
            end = len(file_list)

        p = Process(
            target=chunked_tarring,
            args=(i, file_list[start:end], base_folder, output_folder),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    print("Done")
import os
import subprocess as sp
import sys

SLURM_NTASKS = 6

processes = []
for i in range(SLURM_NTASKS):
    env = os.environ.copy()
    env["SLURM_PROCID"] = str(i)
    env["SLURM_NTASKS"] = str(SLURM_NTASKS)
    env["CUDA_VISIBLE_DEVICES"] = str(i % 8)

    processes.append(
        sp.Popen(
            "python preparing_data/wenet_clean/clean_wenet_speech.py",
            shell=True,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
    )

for p in processes:
    p.wait()
    print(p.communicate())
[project]
name = "speech-lm"
version = "0.0.1"
description = ""
authors = [
    {name = "Lengyue", email = "lengyue@lengyue.me"},
]
dependencies = [
    "torch>=2.1.0",
    "torchvision>=0.16.0",
    "torchaudio>=2.1.0",
    "transformers>=4.34.0",
    "datasets>=2.14.5",
    "bitsandbytes>=0.41.1",
    "peft>=0.5.0",
    "deepspeed>=0.11.1",
    "lightning>=2.0.9.post0",
    "hydra-core>=1.3.2",
    "pyrootutils>=1.0.4",
]
requires-python = ">=3.10"
license = {text = "MIT"}

[tool.setuptools]
py-modules = ["speech_lm"]

[tool.pdm.dev-dependencies]
dev = [
    "black>=23.9.1",
    "isort>=5.12.0",
]
paths:
  run_dir: results/pretrain

hydra:
  run:
    dir: ${paths.run_dir}
import random
from transformers import AutoTokenizer
from datasets import load_dataset, interleave_datasets, IterableDataset
from functools import lru_cache
from torch.utils.data import DataLoader
from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size, is_initialized


@lru_cache(maxsize=1)
def get_tokenizer():
    return AutoTokenizer.from_pretrained("fishaudio/speech-lm-300m", revision="init")


def encode(examples):
    # Randomly choose a 512-character window for each example
    texts = []
    for text in examples["text"]:
        if len(text) <= 512:
            texts.append(text)
        else:
            start = random.randint(0, len(text) - 512)
            texts.append(text[start : start + 512])

    data = get_tokenizer()(
        texts,
        truncation=True,
        padding="max_length",
        max_length=512,
    )

    # Mask padding positions so they are ignored by the loss
    data["labels"] = [
        [token if mask == 1 else -100 for token, mask in zip(tokens, masks)]
        for tokens, masks in zip(data["input_ids"], data["attention_mask"])
    ]

    return data


def build_dataset():
    en_dataset = load_dataset("uonlp/CulturaX", "en", split="train", streaming=True)
    ja_dataset = load_dataset("uonlp/CulturaX", "ja", split="train", streaming=True)
    zh_dataset = load_dataset("uonlp/CulturaX", "zh", split="train", streaming=True)

    multilingual_dataset: IterableDataset = interleave_datasets(
        [en_dataset, ja_dataset, zh_dataset], probabilities=[0.4, 0.3, 0.3], seed=42
    )

    # DDP: give each rank its own shard of the stream
    if is_initialized():
        multilingual_dataset = split_dataset_by_node(
            multilingual_dataset,
            rank=get_rank(),
            world_size=get_world_size(),
        )

    multilingual_dataset = multilingual_dataset.shuffle(seed=42, buffer_size=10000)
    multilingual_dataset = multilingual_dataset.map(
        encode, batched=True, remove_columns=multilingual_dataset.column_names
    )

    return multilingual_dataset


if __name__ == "__main__":
    dataset = build_dataset()
    print(list(dataset.take(16)))
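Since encode() pads every example to max_length=512, the streamed dataset can be batched with the stock collator; a minimal sketch under that assumption (batch size and collator choice are illustrative, not part of this merge request):

from torch.utils.data import DataLoader
from transformers import default_data_collator

# Each example is a dict of equal-length lists (input_ids, attention_mask,
# labels), so the default collator can stack them into tensors directly.
loader = DataLoader(build_dataset(), batch_size=8, collate_fn=default_data_collator)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # expected: torch.Size([8, 512])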
from transformers import LlamaModel, LlamaConfig, AutoTokenizer

# reuse the tokenizer from the llama
model_type = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_type)

# new tokens
new_tokens = [f"<semantic_{i}>" for i in range(4096)]
tokenizer.add_tokens(new_tokens + ["<pad>"])

# pad token
tokenizer.pad_token = "<pad>"
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

print(f"Vocab size: {len(tokenizer)}")

hidden_size = 1024
intermediate_size = hidden_size * (11 / 3)
# then round to the nearest multiple of 8
intermediate_size = round(intermediate_size / 8) * 8

print(f"Hidden size: {hidden_size}")
print(f"Intermediate size: {intermediate_size}")

model = LlamaModel(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_hidden_layers=20,
        num_attention_heads=16,
        max_position_embeddings=4096,
    )
)
model = model.bfloat16()

# Resize the token embeddings to include the new tokens
# Make sure it's a multiple of 8 for faster training
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6:.2f}M")

# Try tokenizing a new sequence
sequence = "Test <semantic_0> <semantic_1023> <pad>"
encoded = tokenizer.encode(sequence)

print("Test encoding....")
print(f"\tSentence: {sequence}")
print(f"\tEncoded: {encoded}")
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")

# model.save_pretrained("./checkpoints/speech-lm-300m-init")
# tokenizer.save_pretrained("./checkpoints/speech-lm-300m-init")

model.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="init")
tokenizer.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="init")
import torch
from lightning.fabric import Fabric
import hydra
from omegaconf import DictConfig, OmegaConf
import pyrootutils

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver and root
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
OmegaConf.register_new_resolver("eval", eval)

# flake8: noqa: E402
from speech_lm.dataset import build_dataset


@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    print(cfg)


if __name__ == "__main__":
    main()