Commit 94a54a14 authored by Lengyue, committed by zjsun

Format code

parent b6417459
@@ -38,8 +38,14 @@ class TrainingArguments(_TrainingArguments):
     use_lora: bool = field(default=False)
 
 
-def dataset_transform(batch, tokenizer: AutoTokenizer=None):
-    outputs = tokenizer(batch["prompt"], padding="longest", truncation=True, max_length=512, return_tensors="pt")
+def dataset_transform(batch, tokenizer: AutoTokenizer = None):
+    outputs = tokenizer(
+        batch["prompt"],
+        padding="longest",
+        truncation=True,
+        max_length=512,
+        return_tensors="pt",
+    )
     labels = outputs.input_ids.clone()
 
     # Set the labels to -100 so that they are ignored by the loss
@@ -51,6 +57,7 @@ def dataset_transform(batch, tokenizer: AutoTokenizer=None):
         "labels": labels,
     }
 
+
 def train():
     parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
@@ -87,11 +94,11 @@ def train():
     try:
         dataset = load_from_disk(data_args.data_path)
-        if 'train' in dataset:
-            dataset = dataset['train']
+        if "train" in dataset:
+            dataset = dataset["train"]
     except:
         dataset = load_dataset(data_args.data_path, split="train")
 
     dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
 
     dataset = dataset.train_test_split(test_size=1000, seed=42)
 
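Note on the hunks above: dataset_transform only appears in fragments here. The following is a rough, self-contained sketch of the usual pattern, with the masking step, model name, and data file assumed rather than taken from this repository; -100 is the default ignore_index of PyTorch's cross-entropy loss, so masked positions are ignored by the loss.

# Rough sketch (not the repository's exact code): tokenize prompts and mask
# padding positions with -100 so the language-modeling loss ignores them.
from functools import partial

from datasets import load_dataset
from transformers import AutoTokenizer


def dataset_transform(batch, tokenizer: AutoTokenizer = None):
    outputs = tokenizer(
        batch["prompt"],
        padding="longest",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    labels = outputs.input_ids.clone()
    # Assumed masking step (collapsed in the diff): padding does not count toward the loss.
    labels[outputs.attention_mask == 0] = -100
    return {
        "input_ids": outputs.input_ids,
        "attention_mask": outputs.attention_mask,
        "labels": labels,
    }


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model name
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
dataset = load_dataset("json", data_files="prompts.jsonl", split="train")  # placeholder data file
dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))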
......
-from pathlib import Path
+import random
 import subprocess
 from multiprocessing import Pool, cpu_count
+from pathlib import Path
 from tqdm import tqdm
-import random
 
 
 def convert_to_flac(src_file_path):
     dst_file_path = src_file_path.with_suffix(".flac")
@@ -10,7 +12,17 @@ def convert_to_flac(src_file_path):
 
     try:
         subprocess.check_call(
-            ["ffmpeg", "-y", "-i", str(src_file_path), "-acodec", "flac", "-threads", "0", str(dst_file_path)],
+            [
+                "ffmpeg",
+                "-y",
+                "-i",
+                str(src_file_path),
+                "-acodec",
+                "flac",
+                "-threads",
+                "0",
+                str(dst_file_path),
+            ],
             stdout=subprocess.DEVNULL,
             stderr=subprocess.DEVNULL,
         )
@@ -33,13 +45,15 @@ if __name__ == "__main__":
     fail_counter = 0
 
     with Pool(processes=cpu_count(), maxtasksperchild=100) as pool:
-        with tqdm(pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)) as pbar:
+        with tqdm(
+            pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)
+        ) as pbar:
             for success in pbar:
                 if success:
                     success_counter += 1
                 else:
                     fail_counter += 1
 
                 pbar.set_description(f"Success: {success_counter}, Fail: {fail_counter}")
 
     print(f"Successfully converted: {success_counter}")
......
 import json
-from pathlib import Path
+import os
 import subprocess
+import tempfile
+import time
+from pathlib import Path
 import librosa
 import soundfile as sf
 import torch
 import torchaudio
 from fish_audio_preprocess.utils.separate_audio import (
-    separate_audio,
-    merge_tracks,
     init_model,
+    merge_tracks,
+    separate_audio,
 )
 from tqdm import tqdm
-import time
-import os
-import tempfile
 
 rank = int(os.environ.get("SLURM_PROCID", 0))
 world_size = int(os.environ.get("SLURM_NTASKS", 1))
@@ -75,7 +75,9 @@ def main():
             )
             # Make it 2 channels
             audio = torch.cat([audio, audio], dim=0)
-            tracks = separate_audio(demucs, audio, shifts=1, num_workers=0, progress=False)
+            tracks = separate_audio(
+                demucs, audio, shifts=1, num_workers=0, progress=False
+            )
             audio = merge_tracks(tracks, filter=["vocals"])[0]
             vocals, sr = (
                 torchaudio.functional.resample(
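Note: only the middle of main() is visible above. A hedged end-to-end sketch of the vocal-separation flow follows; the separate_audio and merge_tracks calls mirror the diff, while the init_model arguments, input file, and target sample rate are assumptions.

# Hedged sketch of the vocal-separation flow around this hunk. Only the
# separate_audio / merge_tracks calls are taken from the diff; everything
# else (model init arguments, file name, target rate) is assumed.
import torch
import torchaudio
from fish_audio_preprocess.utils.separate_audio import (
    init_model,
    merge_tracks,
    separate_audio,
)

demucs = init_model("htdemucs", torch.device("cpu"))  # assumed arguments
audio, sr = torchaudio.load("example.wav")  # assumed input file

if audio.shape[0] == 1:
    # Demucs expects stereo input, so duplicate the mono channel.
    audio = torch.cat([audio, audio], dim=0)

tracks = separate_audio(demucs, audio, shifts=1, num_workers=0, progress=False)
vocals = merge_tracks(tracks, filter=["vocals"])[0]

# Downstream code resamples the separated vocals (target rate assumed here).
vocals = torchaudio.functional.resample(vocals, orig_freq=sr, new_freq=16000)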
......