"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "13aace3d34e7ccc0c4848a9a31814f2673bb6ccd"
Commit 66ea8ff4 authored by Lengyue, committed by zjsun

Init pdm & dataset

parent 957b98c0
.pgx.*
.pdm-python
/speech_lm.egg-info
from dataclasses import dataclass, field
from functools import partial
from typing import Optional
from datasets import load_dataset, load_from_disk
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorWithPadding,
HfArgumentParser,
Trainer,
)
from transformers import TrainingArguments as _TrainingArguments
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
@dataclass
class TrainingArguments(_TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
use_lora: bool = field(default=False)
def dataset_transform(batch, tokenizer: AutoTokenizer = None):
outputs = tokenizer(
batch["prompt"],
padding="longest",
truncation=True,
max_length=512,
return_tensors="pt",
)
labels = outputs.input_ids.clone()
    # Mask padded positions with -100 so they are ignored by the loss
labels[outputs.attention_mask == 0] = -100
return {
"input_ids": outputs.input_ids,
"attention_mask": outputs.attention_mask,
"labels": labels,
}
def train():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
cache_dir=training_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=False,
trust_remote_code=True,
model_max_length=training_args.model_max_length,
cache_dir=training_args.cache_dir,
)
tokenizer.pad_token_id = tokenizer.eos_token_id
if training_args.use_lora:
from peft import LoraConfig, TaskType, get_peft_model
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=["W_pack"],
inference_mode=False,
r=16,
lora_alpha=64,
lora_dropout=0.1,
)
model.enable_input_require_grads()
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
try:
dataset = load_from_disk(data_args.data_path)
if "train" in dataset:
dataset = dataset["train"]
    except Exception:  # data_path is not a local dataset directory; fall back to the Hub
dataset = load_dataset(data_args.data_path, split="train")
dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
dataset = dataset.train_test_split(test_size=1000, seed=42)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
data_collator=DataCollatorWithPadding(tokenizer),
)
trainer.train()
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
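A minimal sketch (not in the commit) to sanity-check dataset_transform above, assuming the function is in scope; "gpt2" is only a stand-in tokenizer for the real base model:
# Hypothetical check: padded positions must end up as -100 in the labels.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")          # stand-in tokenizer, not the trained model
tok.pad_token_id = tok.eos_token_id                  # same padding trick as in train()

batch = {"prompt": ["[INST] hello [/INST] <encodec_1> <encodec_2>", "[INST] hi [/INST] <encodec_3>"]}
out = dataset_transform(batch, tokenizer=tok)
print(out["input_ids"].shape)                        # (2, longest prompt length)
print((out["labels"] == -100).sum().item())          # count of padded positions ignored by the loss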
export NCCL_P2P_DISABLE=1
hostfile=""
deepspeed --hostfile=$hostfile tools/tts/fine-tune.py \
--deepspeed tools/tts/ds_config.json \
--report_to "tensorboard" \
--data_path "fishaudio/libritts-r-encodec" \
--model_name_or_path "checkpoints/llama2-tiny-init" \
--output_dir "results" \
--model_max_length 4096 \
--max_steps 500000 \
--per_device_train_batch_size 32 \
--gradient_accumulation_steps 1 \
--save_strategy steps \
--save_steps 10000 \
--evaluation_strategy steps \
--eval_steps 10000 \
--learning_rate 1e-3 \
--lr_scheduler_type cosine \
--adam_beta1 0.9 \
--adam_beta2 0.98 \
--adam_epsilon 1e-8 \
--max_grad_norm 1.0 \
--weight_decay 1e-4 \
--warmup_steps 10000 \
--logging_steps 1 \
--gradient_checkpointing True \
--remove_unused_columns False \
--use_lora False \
--bf16 True \
--tf32 True
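The launch above passes --deepspeed tools/tts/ds_config.json, which is not shown in this diff. A hypothetical minimal config matching these flags (bf16, ZeRO stage 2, "auto" values resolved by the HF Trainer integration) might look like the sketch below; the real file in the repo may differ.
# Hypothetical tools/tts/ds_config.json, emitted from Python for illustration only.
import json

ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
}

with open("tools/tts/ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)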
import json
import os
from pathlib import Path
import librosa
import torch
from datasets import Dataset
from multiprocess import set_start_method
from transformers import AutoProcessor, EncodecModel
set_start_method("spawn", force=True)
encodec_name = "facebook/encodec_24khz"
encodec_processor = AutoProcessor.from_pretrained(encodec_name)
encodec_model = EncodecModel.from_pretrained(encodec_name)
encodec_model.eval()
def tokenize(text, audio, sr=None, speaker=None):
assert sr is None or sr == encodec_processor.sampling_rate
if isinstance(audio, (str, Path)):
audio, sr = librosa.load(audio, sr=sr, mono=True)
prompt = "[INST] "
if speaker:
prompt += f"[SPK] {speaker} [/SPK] "
prompt += f"{text} [/INST] "
inputs = encodec_processor(
raw_audio=audio, sampling_rate=sr, return_tensors="pt"
).to(encodec_model.device)
outputs = encodec_model.encode(
inputs["input_values"], inputs["padding_mask"], bandwidth=1.5, return_dict=True
)
assert outputs.audio_codes.dim() == 4 # [batch, channel, codebook, code]
assert outputs.audio_codes.shape[0] == outputs.audio_codes.shape[1] == 1
codes = outputs.audio_codes[0, 0, 0, :].long()
codes_str = " ".join([f"<encodec_{int(c)}>" for c in codes.tolist()])
prompt += codes_str
return {
"prompt": prompt,
"codes": codes,
}
def wrap_tokenize(x):
device = torch.device("cuda", 0)
if encodec_model.device != device:
encodec_model.to(device)
return tokenize(
text=x["text"],
audio=x["raw_path"],
sr=encodec_processor.sampling_rate,
speaker=x["speaker"],
)
def generator_libritts_r():
base = Path("dataset/tts/LibriTTS_R")
for i in base.rglob("*.wav"):
text_file = i.with_suffix(".normalized.txt")
if not text_file.exists():
continue
text = text_file.read_text().strip()
yield {
"text": text,
"speaker": f"libritts_{i.parent.parent.name}",
"raw_path": str(i),
"path": str(i.relative_to(base)),
}
if __name__ == "__main__":
dataset = Dataset.from_generator(generator_libritts_r)
dataset = dataset.map(wrap_tokenize, num_proc=12)
dataset = dataset.remove_columns(["raw_path"])
dataset.save_to_disk("dataset/tts/libritts-r-encodec")
dataset.push_to_hub("fishaudio/libritts-r-encodec", private=True)
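A hedged way to spot-check the saved dataset (field names as produced by tokenize above; code counts assume EnCodec 24 kHz with a single codebook at 1.5 kbps):
# Sketch: load the dataset written above and inspect one record.
from datasets import load_from_disk

ds = load_from_disk("dataset/tts/libritts-r-encodec")
row = ds[0]
print(row["speaker"], row["path"])
print(row["prompt"][:120])   # "[INST] [SPK] libritts_... [/SPK] <text> [/INST] <encodec_...> ..."
print(len(row["codes"]))     # roughly 75 codes per second of audio at the 24 kHz EnCodec frame rate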
[project]
name = "speech-lm"
version = "0.0.1"
description = ""
authors = [
{name = "Lengyue", email = "lengyue@lengyue.me"},
]
dependencies = [
"torch>=2.1.0",
"torchvision>=0.16.0",
"torchaudio>=2.1.0",
"transformers>=4.34.0",
"datasets>=2.14.5",
"accelerate>=0.23.0",
"bitsandbytes>=0.41.1",
"peft>=0.5.0",
"omegaconf>=2.3.0",
]
requires-python = ">=3.10"
license = {text = "MIT"}
[tool.setuptools]
py-modules = ["speech_lm"]
[tool.pdm.dev-dependencies]
dev = [
"black>=23.9.1",
"isort>=5.12.0",
]
import random
from transformers import AutoTokenizer
from datasets import load_dataset, interleave_datasets, IterableDataset
from functools import lru_cache
from torch.utils.data import DataLoader
from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size, is_initialized
@lru_cache(maxsize=1)
def get_tokenizer():
return AutoTokenizer.from_pretrained("fishaudio/speech-lm-300m", revision="init")
def encode(examples):
    # Randomly pick a 512-character window from each example before tokenizing
texts = []
for text in examples["text"]:
if len(text) <= 512:
texts.append(text)
else:
start = random.randint(0, len(text) - 512)
texts.append(text[start : start + 512])
data = get_tokenizer()(
texts,
truncation=True,
padding="max_length",
max_length=512,
)
data["labels"] = data["input_ids"].copy()
data["labels"][data["attention_mask"] == 0] = -100
return data
def build_dataset():
en_dataset = load_dataset("uonlp/CulturaX", "en", split="train", streaming=True)
ja_dataset = load_dataset("uonlp/CulturaX", "ja", split="train", streaming=True)
zh_dataset = load_dataset("uonlp/CulturaX", "zh", split="train", streaming=True)
multilingual_dataset: IterableDataset = interleave_datasets(
[en_dataset, ja_dataset, zh_dataset], probabilities=[0.4, 0.3, 0.3], seed=42
)
# DDP
if is_initialized():
multilingual_dataset = split_dataset_by_node(
multilingual_dataset,
rank=get_rank(),
num_replicas=get_world_size(),
)
multilingual_dataset = multilingual_dataset.shuffle(seed=42, buffer_size=10000)
multilingual_dataset = multilingual_dataset.map(
encode, batched=True, remove_columns=multilingual_dataset.column_names
)
return multilingual_dataset
if __name__ == "__main__":
dataset = build_dataset()
print(list(dataset.take(16)))
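Because build_dataset yields fixed-length 512-token rows, the streaming dataset can go straight into a torch DataLoader; a minimal sketch using the stock default_data_collator (which simply tensorizes the already-padded columns):
# Sketch: batch the streaming dataset; assumes build_dataset() from the script above.
from torch.utils.data import DataLoader
from transformers import default_data_collator

loader = DataLoader(build_dataset(), batch_size=8, collate_fn=default_data_collator)
batch = next(iter(loader))
print(batch["input_ids"].shape, batch["labels"].shape)  # torch.Size([8, 512]) each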
from transformers import LlamaModel, LlamaConfig, AutoTokenizer
# Reuse the tokenizer from Llama 2
model_type = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_type)
# new tokens
new_tokens = [f"<semantic_{i}>" for i in range(4096)]
tokenizer.add_tokens(new_tokens + ["<pad>"])
# pad token
tokenizer.pad_token = "<pad>"
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
print(f"Vocab size: {len(tokenizer)}")
hidden_size = 1024
intermediate_size = hidden_size * (11 / 3)
# then round to the nearest multiple of 8
intermediate_size = round(intermediate_size / 8) * 8
print(f"Hidden size: {hidden_size}")
print(f"Intermediate size: {intermediate_size}")
model = LlamaModel(
LlamaConfig(
vocab_size=tokenizer.vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=20,
num_attention_heads=16,
max_position_embeddings=4096,
)
)
model = model.bfloat16()
# Resize the token embeddings to include the new tokens
# Make sure it's a multiple of 8 for faster training
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6:.2f}M")
# Try tokenizing a new sequence
sequence = "Test <semantic_0> <semantic_1023> <pad>"
encoded = tokenizer.encode(sequence)
print("Test encoding....")
print(f"\tSentence: {sequence}")
print(f"\tEncoded: {encoded}")
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
# model.save_pretrained("./checkpoints/speech-lm-300m-init")
# tokenizer.save_pretrained("./checkpoints/speech-lm-300m-init")
model.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="init")
tokenizer.push_to_hub("fishaudio/speech-lm-300m", private=True, revision="init")
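To verify what was pushed, the init revision can be reloaded and the vocabulary sizes compared (a sketch; the repo is private, so this assumes authenticated Hub access):
# Sketch: confirm the resized, 8-aligned embedding table on the pushed checkpoint.
from transformers import AutoTokenizer, LlamaModel

tok = AutoTokenizer.from_pretrained("fishaudio/speech-lm-300m", revision="init")
mdl = LlamaModel.from_pretrained("fishaudio/speech-lm-300m", revision="init")
print(len(tok))                                     # 32000 + 4096 + 1 = 36097 tokens
print(mdl.get_input_embeddings().weight.shape[0])   # 36104 rows, padded to a multiple of 8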