Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zjsun
fish-speech
Commits
85f0282a
Commit
85f0282a
authored
Oct 11, 2023
by
Lengyue
Committed by
zjsun
Sep 03, 2025
Browse files
init training code
parent
66ea8ff4
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
864 additions
and
3 deletions
+864
-3
.gitignore
.gitignore
+2
-0
.project-root
.project-root
+1
-0
pdm.lock
pdm.lock
+740
-1
preparing_data/prepare_dataset.py
preparing_data/prepare_dataset.py
+88
-0
pyproject.toml
pyproject.toml
+4
-2
speech_lm/configs/pretrain.yaml
speech_lm/configs/pretrain.yaml
+6
-0
speech_lm/train.py
speech_lm/train.py
+23
-0
No files found.
.gitignore
View file @
85f0282a
# Tooling / build artifacts
.pgx.*
.pdm-python
/speech_lm.egg-info
__pycache__
# Training run outputs
/results
.project-root
0 → 100644
View file @
85f0282a
ROOT
pdm.lock
View file @
85f0282a
This diff is collapsed.
Click to expand it.
preparing_data/prepare_dataset.py
0 → 100644
View file @
85f0282a
import
json
import
os
from
pathlib
import
Path
import
librosa
import
torch
from
datasets
import
Dataset
from
multiprocess
import
set_start_method
from
transformers
import
AutoProcessor
,
EncodecModel
# Use "spawn" so CUDA can be initialised safely inside dataset.map worker
# processes (a forked child would inherit an unusable CUDA context).
set_start_method("spawn", force=True)

# Shared EnCodec neural audio codec (24 kHz checkpoint) used to turn
# waveforms into discrete code indices. Loaded once at import time; each
# spawned worker re-imports this module and gets its own copy.
encodec_name = "facebook/encodec_24khz"
encodec_processor = AutoProcessor.from_pretrained(encodec_name)
encodec_model = EncodecModel.from_pretrained(encodec_name)
# Inference only — switch off dropout etc.
encodec_model.eval()
def tokenize(text, audio, sr=None, speaker=None):
    """Encode one utterance into an EnCodec-token prompt string.

    Args:
        text: transcript of the utterance.
        audio: waveform array, or a str/Path that is loaded via librosa.
        sr: sample rate of ``audio``; when given it must equal the EnCodec
            processor's rate (None lets librosa report the file's rate).
        speaker: optional speaker tag embedded into the prompt.

    Returns:
        dict with ``prompt`` (instruction text followed by "<encodec_N>"
        tokens) and ``codes`` (first-codebook indices as a long tensor).
    """
    assert sr is None or sr == encodec_processor.sampling_rate

    # Path-like input is loaded from disk and downmixed to mono.
    if isinstance(audio, (str, Path)):
        audio, sr = librosa.load(audio, sr=sr, mono=True)

    parts = ["[INST] "]
    if speaker:
        parts.append(f"[SPK] {speaker} [/SPK] ")
    parts.append(f"{text} [/INST] ")

    features = encodec_processor(
        raw_audio=audio, sampling_rate=sr, return_tensors="pt"
    ).to(encodec_model.device)
    encoded = encodec_model.encode(
        features["input_values"],
        features["padding_mask"],
        bandwidth=1.5,
        return_dict=True,
    )

    # audio_codes is [batch, channel, codebook, code]; a single item with a
    # single channel is expected, and only the first codebook is kept.
    assert encoded.audio_codes.dim() == 4
    assert encoded.audio_codes.shape[0] == encoded.audio_codes.shape[1] == 1
    codes = encoded.audio_codes[0, 0, 0, :].long()

    parts.append(" ".join(f"<encodec_{int(c)}>" for c in codes.tolist()))

    return {
        "prompt": "".join(parts),
        "codes": codes,
    }
def wrap_tokenize(x):
    """`datasets.map` adapter: tokenize a single example dict in a worker.

    Expects ``x`` to carry the "text", "raw_path" and "speaker" keys
    produced by the dataset generators in this file.
    """
    # Move the shared EnCodec model to the GPU once per spawned worker.
    # Fall back to CPU when CUDA is unavailable — the original hard-coded
    # cuda:0 and crashed on CPU-only machines; on GPU hosts behavior is
    # unchanged.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    if encodec_model.device != device:
        encodec_model.to(device)

    return tokenize(
        text=x["text"],
        audio=x["raw_path"],
        sr=encodec_processor.sampling_rate,
        speaker=x["speaker"],
    )
def generator_libritts_r():
    """Yield LibriTTS-R training examples from dataset/tts/LibriTTS_R.

    Every .wav file that has a sibling ``.normalized.txt`` transcript
    produces a dict with the transcript text, a speaker tag, the wav path
    as given, and the path relative to the corpus root.
    """
    root = Path("dataset/tts/LibriTTS_R")
    for wav_path in root.rglob("*.wav"):
        transcript_path = wav_path.with_suffix(".normalized.txt")
        # Audio without a normalized transcript next to it is skipped.
        if transcript_path.exists():
            # Speaker tag uses the grandparent directory name — presumably
            # the LibriTTS speaker id (speaker/chapter/file layout).
            yield {
                "text": transcript_path.read_text().strip(),
                "speaker": f"libritts_{wav_path.parent.parent.name}",
                "raw_path": str(wav_path),
                "path": str(wav_path.relative_to(root)),
            }
if __name__ == "__main__":
    # Build the dataset from local LibriTTS-R files, encode every clip with
    # EnCodec across 12 spawned worker processes, then persist the result
    # locally and on the Hugging Face Hub.
    dataset = Dataset.from_generator(generator_libritts_r)
    dataset = dataset.map(wrap_tokenize, num_proc=12)
    # The on-disk wav path is only needed during tokenization.
    dataset = dataset.remove_columns(["raw_path"])
    dataset.save_to_disk("dataset/tts/libritts-r-encodec")
    # Private upload — requires Hugging Face authentication.
    dataset.push_to_hub("fishaudio/libritts-r-encodec", private=True)
pyproject.toml
View file @
85f0282a
...
...
@@ -12,10 +12,12 @@ dependencies = [
"torchaudio>=2.1.0"
,
"transformers>=4.34.0"
,
"datasets>=2.14.5"
,
"accelerate>=0.23.0"
,
"bitsandbytes>=0.41.1"
,
"peft>=0.5.0"
,
"omegaconf>=2.3.0"
,
"deepspeed>=0.11.1"
,
"lightning>=2.0.9.post0"
,
"hydra-core>=1.3.2"
,
"pyrootutils>=1.0.4"
,
]
requires-python
=
">=3.10"
license
=
{
text
=
"MIT"
}
...
...
speech_lm/configs/pretrain.yaml
0 → 100644
View file @
85f0282a
# Output locations for a pretraining run.
paths:
  run_dir: results/pretrain

# Point Hydra's own working/output directory at the same run directory so
# logs and config snapshots land next to the results.
hydra:
  run:
    dir: ${paths.run_dir}
speech_lm/train.py
0 → 100644
View file @
85f0282a
import torch
from lightning.fabric import Fabric
import hydra
from omegaconf import DictConfig, OmegaConf
import pyrootutils

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver and root
# setup_root locates the repo root via the .project-root marker file and
# adds it to sys.path so `speech_lm` imports resolve when run as a script.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# NOTE(review): the "eval" resolver executes arbitrary Python embedded in
# config files — acceptable only for author-controlled configs; never load
# untrusted YAML through OmegaConf with this resolver registered.
OmegaConf.register_new_resolver("eval", eval)

# flake8: noqa: E402
# Imported after setup_root so the project package is on sys.path.
from speech_lm.dataset import build_dataset
@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    """Training entry point; Hydra injects the resolved pretrain config.

    Currently a stub that only echoes the configuration.
    """
    print(cfg)


if __name__ == "__main__":
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment