Commit 94a54a14 authored by Lengyue's avatar Lengyue Committed by zjsun
Browse files

Format code

parent b6417459
...@@ -38,8 +38,14 @@ class TrainingArguments(_TrainingArguments): ...@@ -38,8 +38,14 @@ class TrainingArguments(_TrainingArguments):
use_lora: bool = field(default=False) use_lora: bool = field(default=False)
def dataset_transform(batch, tokenizer: AutoTokenizer=None): def dataset_transform(batch, tokenizer: AutoTokenizer = None):
outputs = tokenizer(batch["prompt"], padding="longest", truncation=True, max_length=512, return_tensors="pt") outputs = tokenizer(
batch["prompt"],
padding="longest",
truncation=True,
max_length=512,
return_tensors="pt",
)
labels = outputs.input_ids.clone() labels = outputs.input_ids.clone()
# Set the labels to -100 so that the logits are not affected by loss # Set the labels to -100 so that the logits are not affected by loss
...@@ -51,6 +57,7 @@ def dataset_transform(batch, tokenizer: AutoTokenizer=None): ...@@ -51,6 +57,7 @@ def dataset_transform(batch, tokenizer: AutoTokenizer=None):
"labels": labels, "labels": labels,
} }
def train(): def train():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args = parser.parse_args_into_dataclasses()
...@@ -87,11 +94,11 @@ def train(): ...@@ -87,11 +94,11 @@ def train():
try: try:
dataset = load_from_disk(data_args.data_path) dataset = load_from_disk(data_args.data_path)
if 'train' in dataset: if "train" in dataset:
dataset = dataset['train'] dataset = dataset["train"]
except: except:
dataset = load_dataset(data_args.data_path, split="train") dataset = load_dataset(data_args.data_path, split="train")
dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer)) dataset.set_transform(partial(dataset_transform, tokenizer=tokenizer))
dataset = dataset.train_test_split(test_size=1000, seed=42) dataset = dataset.train_test_split(test_size=1000, seed=42)
......
from pathlib import Path import random
import subprocess import subprocess
from multiprocessing import Pool, cpu_count from multiprocessing import Pool, cpu_count
from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
import random
def convert_to_flac(src_file_path): def convert_to_flac(src_file_path):
dst_file_path = src_file_path.with_suffix(".flac") dst_file_path = src_file_path.with_suffix(".flac")
...@@ -10,7 +12,17 @@ def convert_to_flac(src_file_path): ...@@ -10,7 +12,17 @@ def convert_to_flac(src_file_path):
try: try:
subprocess.check_call( subprocess.check_call(
["ffmpeg", "-y", "-i", str(src_file_path), "-acodec", "flac", "-threads", "0", str(dst_file_path)], [
"ffmpeg",
"-y",
"-i",
str(src_file_path),
"-acodec",
"flac",
"-threads",
"0",
str(dst_file_path),
],
stdout=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
) )
...@@ -33,13 +45,15 @@ if __name__ == "__main__": ...@@ -33,13 +45,15 @@ if __name__ == "__main__":
fail_counter = 0 fail_counter = 0
with Pool(processes=cpu_count(), maxtasksperchild=100) as pool: with Pool(processes=cpu_count(), maxtasksperchild=100) as pool:
with tqdm(pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)) as pbar: with tqdm(
pool.imap_unordered(convert_to_flac, wav_files), total=len(wav_files)
) as pbar:
for success in pbar: for success in pbar:
if success: if success:
success_counter += 1 success_counter += 1
else: else:
fail_counter += 1 fail_counter += 1
pbar.set_description(f"Success: {success_counter}, Fail: {fail_counter}") pbar.set_description(f"Success: {success_counter}, Fail: {fail_counter}")
print(f"Successfully converted: {success_counter}") print(f"Successfully converted: {success_counter}")
......
import json import json
from pathlib import Path import os
import subprocess import subprocess
import tempfile
import time
from pathlib import Path
import librosa import librosa
import soundfile as sf import soundfile as sf
import torch import torch
import torchaudio import torchaudio
from fish_audio_preprocess.utils.separate_audio import ( from fish_audio_preprocess.utils.separate_audio import (
separate_audio,
merge_tracks,
init_model, init_model,
merge_tracks,
separate_audio,
) )
from tqdm import tqdm from tqdm import tqdm
import time
import os
import tempfile
rank = int(os.environ.get("SLURM_PROCID", 0)) rank = int(os.environ.get("SLURM_PROCID", 0))
world_size = int(os.environ.get("SLURM_NTASKS", 1)) world_size = int(os.environ.get("SLURM_NTASKS", 1))
...@@ -75,7 +75,9 @@ def main(): ...@@ -75,7 +75,9 @@ def main():
) )
# Make it 2 channels # Make it 2 channels
audio = torch.cat([audio, audio], dim=0) audio = torch.cat([audio, audio], dim=0)
tracks = separate_audio(demucs, audio, shifts=1, num_workers=0, progress=False) tracks = separate_audio(
demucs, audio, shifts=1, num_workers=0, progress=False
)
audio = merge_tracks(tracks, filter=["vocals"])[0] audio = merge_tracks(tracks, filter=["vocals"])[0]
vocals, sr = ( vocals, sr = (
torchaudio.functional.resample( torchaudio.functional.resample(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment