chenpangpang / parler-tts · Commits

Commit 80da6b4c
Authored Mar 14, 2024 by yoach@huggingface.co

    fix eval when fp16 + remove useless code

Parent: e51113f9

Showing 1 changed file with 8 additions and 19 deletions:

run_stable_speech_training.py  (+8, -19)
...
@@ -192,22 +192,6 @@ def log_pred(
             ]},
             step=step
         )


 #### ARGUMENTS
-class StableSpeechTrainer(Seq2SeqTrainer):
-    def _pad_tensors_to_max_len(self, tensor, max_length):
-        if self.model.config.pad_token_id is not None:
-            pad_token_id = self.model.config.pad_token_id
-        else:
-            raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
-
-        padded_tensor = pad_token_id * torch.ones(
-            (tensor.shape[0], max_length, tensor.shape[2]), dtype=tensor.dtype, device=tensor.device
-        )
-        padded_tensor[:, : tensor.shape[1]] = tensor
-        return padded_tensor
-
-
 @dataclass
 class ModelArguments:
     """
...
@@ -1349,8 +1333,13 @@ def main():
         return ce_loss, metrics

     # Define eval fn
-    def eval_step(batch):
+    def eval_step(
+        batch,
+        accelerator,
+        autocast_kwargs,
+    ):
         model.eval()

+        if mixed_precision == "fp16":
+            # fp16 doesn't work with T5-like models
+            with accelerator.autocast(autocast_handler=autocast_kwargs):
+                encoder_outputs = model.module.text_encoder(
+                    input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None)
+                )
+            batch["encoder_outputs"] = encoder_outputs
+
         with torch.no_grad():
             outputs = model(**batch)
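This hunk is the fp16 fix itself: T5-like text encoders are prone to inf/NaN activations under float16 autocast, so eval_step now takes the accelerator and an autocast handler, runs the text-encoder forward under that handler, and caches the result in batch["encoder_outputs"] so the main forward reuses it instead of recomputing it in fp16. A minimal sketch of the pattern, assuming autocast_kwargs is an accelerate.utils.AutocastKwargs(enabled=False) built elsewhere in the script (that construction is not shown in this diff):

    # Minimal sketch of the pattern above, not the training script itself.
    import torch
    from accelerate.utils import AutocastKwargs

    # Assumption: constructed elsewhere in the script; disables autocast so the
    # T5-like text encoder runs in full precision even under fp16 training.
    autocast_kwargs = AutocastKwargs(enabled=False)

    def eval_step(batch, model, accelerator, autocast_kwargs, mixed_precision):
        model.eval()
        if mixed_precision == "fp16":
            # Compute the text encoder outside fp16 autocast, then hand its
            # outputs to the main forward so it is not recomputed in fp16.
            with accelerator.autocast(autocast_handler=autocast_kwargs):
                encoder_outputs = model.module.text_encoder(  # .module: model is DDP-wrapped
                    input_ids=batch.get("input_ids"),
                    attention_mask=batch.get("attention_mask", None),
                )
            batch["encoder_outputs"] = encoder_outputs
        with torch.no_grad():
            outputs = model(**batch)
        return outputs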
...
@@ -1361,7 +1350,7 @@ def main():
     def generate_step(batch):
         model.eval()

-        output_audios = accelerator.unwrap_model(model).generate(**batch, **gen_kwargs)
+        output_audios = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").generate(**batch, **gen_kwargs)
         output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
         return output_audios
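The second half of the fix: under mixed precision, accelerate wraps the model's forward so that outputs are converted back to fp32, and unwrap_model keeps that wrapper by default (keep_fp32_wrapper=True). Passing keep_fp32_wrapper=mixed_precision != "fp16" strips the wrapper for fp16 runs, so .generate executes on the bare module. A sketch of the updated call, with illustrative placeholder gen_kwargs:

    # Sketch of the updated generate_step; `gen_kwargs` values are placeholders.
    gen_kwargs = {"do_sample": True, "max_new_tokens": 256}  # illustrative only

    def generate_step(batch, model, accelerator, mixed_precision):
        model.eval()
        # For fp16, drop accelerate's fp32 forward wrapper before generation;
        # otherwise keep it (keep_fp32_wrapper defaults to True).
        unwrapped = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16")
        output_audios = unwrapped.generate(**batch, **gen_kwargs)
        # Pad to a common length across processes so outputs can be gathered.
        output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
        return output_audios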
...
@@ -1470,7 +1459,7 @@ def main():
         disable=not accelerator.is_local_main_process,
     ):
         # Model forward
-        eval_metric = eval_step(batch)
+        eval_metric = eval_step(batch, accelerator, autocast_kwargs)
         eval_metric = accelerator.gather_for_metrics(eval_metric)
         eval_metrics.append(eval_metric)
...
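At the call site, the evaluation loop now threads the accelerator and autocast handler through to eval_step; gather_for_metrics then collects each process's metrics and drops the duplicate samples accelerate pads onto the final batch. A sketch of the surrounding loop, where validation_dataloader is an assumed name from elsewhere in main():

    # Sketch of the surrounding evaluation loop; `validation_dataloader`,
    # `accelerator`, and `autocast_kwargs` are assumed to exist in main().
    from tqdm import tqdm

    eval_metrics = []
    for batch in tqdm(
        validation_dataloader,
        desc="Evaluating...",
        disable=not accelerator.is_local_main_process,
    ):
        # Model forward
        eval_metric = eval_step(batch, accelerator, autocast_kwargs)
        # Collate per-process results and drop duplicated samples from the
        # padded final batch before aggregation.
        eval_metric = accelerator.gather_for_metrics(eval_metric)
        eval_metrics.append(eval_metric)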