Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
03611f97
"examples/research_projects/colossalai/requirement.txt" did not exist on "c2283310688ff75e8fb4be3d9938ed0818cb038d"
Commit
03611f97
authored
Mar 05, 2024
by
Yoach Lacombe
Browse files
optimize GPU memory usage
parent
d112db94
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
run_stable_speech_training.py
run_stable_speech_training.py
+7
-4
No files found.
run_stable_speech_training.py
View file @
03611f97
...
@@ -70,15 +70,13 @@ AutoModel.register(DACConfig, DACModel)
...
@@ -70,15 +70,13 @@ AutoModel.register(DACConfig, DACModel)
from
accelerate
import
Accelerator
from
accelerate
import
Accelerator
from
accelerate.utils
import
set_seed
from
accelerate.utils
import
set_seed
from
accelerate.utils.memory
import
release_memory
from
stable_speech
import
StableSpeechForConditionalGeneration
,
StableSpeechConfig
,
apply_delay_pattern_mask
,
build_delay_pattern_mask
from
stable_speech
import
StableSpeechForConditionalGeneration
,
StableSpeechConfig
,
apply_delay_pattern_mask
,
build_delay_pattern_mask
if
is_wandb_available
():
if
is_wandb_available
():
from
wandb
import
Audio
from
wandb
import
Audio
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version
(
"4.38.0.dev0"
)
check_min_version
(
"4.38.0.dev0"
)
...
@@ -1122,11 +1120,13 @@ def main():
...
@@ -1122,11 +1120,13 @@ def main():
cosine_sim
=
torch
.
nn
.
functional
.
cosine_similarity
(
audio_features
,
text_features
,
dim
=
1
,
eps
=
1e-8
)
cosine_sim
=
torch
.
nn
.
functional
.
cosine_similarity
(
audio_features
,
text_features
,
dim
=
1
,
eps
=
1e-8
)
clap
.
to
(
"cpu"
)
clap_inputs
.
to
(
"cpu"
)
return
cosine_sim
.
mean
()
return
cosine_sim
.
mean
()
def
wer
(
prompts
,
audios
,
device
):
def
wer
(
prompts
,
audios
,
device
):
asr_pipeline
=
pipeline
(
model
=
"distil-whisper/distil-large-v2"
,
device
=
device
)
asr_pipeline
=
pipeline
(
model
=
"distil-whisper/distil-large-v2"
,
device
=
device
)
transcriptions
=
asr_pipeline
([{
'raw'
:
audio
,
'sampling_rate'
:
sampling_rate
}
for
audio
in
audios
])
transcriptions
=
asr_pipeline
([{
'raw'
:
audio
,
'sampling_rate'
:
sampling_rate
}
for
audio
in
audios
]
,
batch_size
=
int
(
training_args
.
per_device_eval_batch_size
)
)
word_error
=
100
*
metric
.
compute
(
predictions
=
[
t
[
"text"
].
lower
()
for
t
in
transcriptions
],
references
=
[
t
.
lower
()
for
t
in
prompts
])
word_error
=
100
*
metric
.
compute
(
predictions
=
[
t
[
"text"
].
lower
()
for
t
in
transcriptions
],
references
=
[
t
.
lower
()
for
t
in
prompts
])
...
@@ -1418,6 +1418,9 @@ def main():
...
@@ -1418,6 +1418,9 @@ def main():
eval_descriptions
=
[]
eval_descriptions
=
[]
eval_prompts
=
[]
eval_prompts
=
[]
eval_start
=
time
.
time
()
eval_start
=
time
.
time
()
# release training input batch
batch
=
release_memory
(
batch
)
validation_dataloader
=
DataLoader
(
validation_dataloader
=
DataLoader
(
vectorized_datasets
[
"eval"
],
vectorized_datasets
[
"eval"
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment