Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zjsun
fish-speech
Commits
94a54a14
Commit
94a54a14
authored
Oct 10, 2023
by
Lengyue
Committed by
zjsun
Sep 03, 2025
Browse files
Format code
parent
b6417459
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
17 deletions
+40
-17
fine-tune.py
fine-tune.py
+12
-5
preparing_data/to_flac.py
preparing_data/to_flac.py
+19
-5
preparing_data/wenet_clean/clean_wenet_speech.py
preparing_data/wenet_clean/clean_wenet_speech.py
+9
-7
No files found.
fine-tune.py
View file @
94a54a14
...
...
@@ -38,8 +38,14 @@ class TrainingArguments(_TrainingArguments):
use_lora
:
bool
=
field
(
default
=
False
)
def
dataset_transform
(
batch
,
tokenizer
:
AutoTokenizer
=
None
):
outputs
=
tokenizer
(
batch
[
"prompt"
],
padding
=
"longest"
,
truncation
=
True
,
max_length
=
512
,
return_tensors
=
"pt"
)
def
dataset_transform
(
batch
,
tokenizer
:
AutoTokenizer
=
None
):
outputs
=
tokenizer
(
batch
[
"prompt"
],
padding
=
"longest"
,
truncation
=
True
,
max_length
=
512
,
return_tensors
=
"pt"
,
)
labels
=
outputs
.
input_ids
.
clone
()
# Set the labels to -100 so that the logits are not affected by loss
...
...
@@ -51,6 +57,7 @@ def dataset_transform(batch, tokenizer: AutoTokenizer=None):
"labels"
:
labels
,
}
def
train
():
parser
=
HfArgumentParser
((
ModelArguments
,
DataArguments
,
TrainingArguments
))
model_args
,
data_args
,
training_args
=
parser
.
parse_args_into_dataclasses
()
...
...
@@ -87,11 +94,11 @@ def train():
try
:
dataset
=
load_from_disk
(
data_args
.
data_path
)
if
'
train
'
in
dataset
:
dataset
=
dataset
[
'
train
'
]
if
"
train
"
in
dataset
:
dataset
=
dataset
[
"
train
"
]
except
:
dataset
=
load_dataset
(
data_args
.
data_path
,
split
=
"train"
)
dataset
.
set_transform
(
partial
(
dataset_transform
,
tokenizer
=
tokenizer
))
dataset
=
dataset
.
train_test_split
(
test_size
=
1000
,
seed
=
42
)
...
...
preparing_data/to_flac.py
View file @
94a54a14
from
pathlib
import
Path
import
random
import
subprocess
from
multiprocessing
import
Pool
,
cpu_count
from
pathlib
import
Path
from
tqdm
import
tqdm
import
random
def
convert_to_flac
(
src_file_path
):
dst_file_path
=
src_file_path
.
with_suffix
(
".flac"
)
...
...
@@ -10,7 +12,17 @@ def convert_to_flac(src_file_path):
try
:
subprocess
.
check_call
(
[
"ffmpeg"
,
"-y"
,
"-i"
,
str
(
src_file_path
),
"-acodec"
,
"flac"
,
"-threads"
,
"0"
,
str
(
dst_file_path
)],
[
"ffmpeg"
,
"-y"
,
"-i"
,
str
(
src_file_path
),
"-acodec"
,
"flac"
,
"-threads"
,
"0"
,
str
(
dst_file_path
),
],
stdout
=
subprocess
.
DEVNULL
,
stderr
=
subprocess
.
DEVNULL
,
)
...
...
@@ -33,13 +45,15 @@ if __name__ == "__main__":
fail_counter
=
0
with
Pool
(
processes
=
cpu_count
(),
maxtasksperchild
=
100
)
as
pool
:
with
tqdm
(
pool
.
imap_unordered
(
convert_to_flac
,
wav_files
),
total
=
len
(
wav_files
))
as
pbar
:
with
tqdm
(
pool
.
imap_unordered
(
convert_to_flac
,
wav_files
),
total
=
len
(
wav_files
)
)
as
pbar
:
for
success
in
pbar
:
if
success
:
success_counter
+=
1
else
:
fail_counter
+=
1
pbar
.
set_description
(
f
"Success:
{
success_counter
}
, Fail:
{
fail_counter
}
"
)
print
(
f
"Successfully converted:
{
success_counter
}
"
)
...
...
preparing_data/wenet_clean/clean_wenet_speech.py
View file @
94a54a14
import
json
from
pathlib
import
Path
import
os
import
subprocess
import
tempfile
import
time
from
pathlib
import
Path
import
librosa
import
soundfile
as
sf
import
torch
import
torchaudio
from
fish_audio_preprocess.utils.separate_audio
import
(
separate_audio
,
merge_tracks
,
init_model
,
merge_tracks
,
separate_audio
,
)
from
tqdm
import
tqdm
import
time
import
os
import
tempfile
rank
=
int
(
os
.
environ
.
get
(
"SLURM_PROCID"
,
0
))
world_size
=
int
(
os
.
environ
.
get
(
"SLURM_NTASKS"
,
1
))
...
...
@@ -75,7 +75,9 @@ def main():
)
# Make it 2 channels
audio
=
torch
.
cat
([
audio
,
audio
],
dim
=
0
)
tracks
=
separate_audio
(
demucs
,
audio
,
shifts
=
1
,
num_workers
=
0
,
progress
=
False
)
tracks
=
separate_audio
(
demucs
,
audio
,
shifts
=
1
,
num_workers
=
0
,
progress
=
False
)
audio
=
merge_tracks
(
tracks
,
filter
=
[
"vocals"
])[
0
]
vocals
,
sr
=
(
torchaudio
.
functional
.
resample
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment