chenpangpang / parler-tts

Commit 10ef6f6c
Authored Apr 08, 2024 by Yoach Lacombe
Parent: 82cbc3ad

improve torch compile logic
Showing 1 changed file with 16 additions and 19 deletions.
run_stable_speech_training.py @ 10ef6f6c  (+16 / -19)
@@ -498,15 +498,7 @@ class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
             )
         },
     )
-    text_encode_per_device_eval_batch_size: int = field(
-        default=8,
-        metadata={
-            "help": (
-                "TODO"
-            )
-        },
-    )

 @dataclass
 class DataCollatorEncodecWithPadding:
     """
@@ -821,7 +813,7 @@ def main():
     kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
     if training_args.torch_compile:
         # TODO(YL): add more compile modes?
-        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor"))
+        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default"))  #reduce-overhead
     accelerator = Accelerator(
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
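Note on the hunk above: TorchDynamoPlugin is Accelerate's configuration object for torch.compile; "default" is the standard Inductor mode, while "reduce-overhead" (kept as the trailing comment) is the CUDA-graphs mode that trades extra memory for lower per-step launch overhead. Below is a minimal sketch of the same wiring, with hypothetical stand-ins for training_args and the gradient-accumulation value, and assuming an accelerate release that accepts a TorchDynamoPlugin through kwargs_handlers as this script does:

from datetime import timedelta

from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs, TorchDynamoPlugin

torch_compile = True  # hypothetical stand-in for training_args.torch_compile

# Longer process-group init timeout, as in the script.
kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
if torch_compile:
    # Route torch.compile settings through Accelerate: Inductor backend, "default" mode.
    kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default"))

accelerator = Accelerator(
    gradient_accumulation_steps=1,  # hypothetical value
    kwargs_handlers=kwargs_handlers,
)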
@@ -1493,27 +1485,33 @@ def main():
...
@@ -1493,27 +1485,33 @@ def main():
# Define eval fn
# Define eval fn
def
eval_step
(
batch
,
accelerator
,
autocast_kwargs
,):
def
eval_step
(
batch
,
accelerator
,
autocast_kwargs
,):
model
.
eval
()
eval_model
=
model
if
not
training_args
.
torch_compile
else
model
.
_orig_mod
eval_model
.
eval
()
if
mixed_precision
==
"fp16"
:
if
mixed_precision
==
"fp16"
:
# fp16 doesn't work with T5-like models
# fp16 doesn't work with T5-like models
with
accelerator
.
autocast
(
autocast_handler
=
autocast_kwargs
):
with
accelerator
.
autocast
(
autocast_handler
=
autocast_kwargs
):
if
training_args
.
parallel_mode
.
value
!=
"distributed"
:
with
torch
.
no_grad
():
encoder_outputs
=
model
.
text_encoder
(
input_ids
=
batch
.
get
(
"input_ids"
),
attention_mask
=
batch
.
get
(
"attention_mask"
,
None
))
if
training_args
.
parallel_mode
.
value
!=
"distributed"
or
training_args
.
torch_compile
:
else
:
encoder_outputs
=
eval_model
.
text_encoder
(
input_ids
=
batch
.
get
(
"input_ids"
),
attention_mask
=
batch
.
get
(
"attention_mask"
,
None
))
encoder_outputs
=
model
.
module
.
text_encoder
(
input_ids
=
batch
.
get
(
"input_ids"
),
attention_mask
=
batch
.
get
(
"attention_mask"
,
None
))
else
:
encoder_outputs
=
eval_model
.
module
.
text_encoder
(
input_ids
=
batch
.
get
(
"input_ids"
),
attention_mask
=
batch
.
get
(
"attention_mask"
,
None
))
batch
[
"encoder_outputs"
]
=
encoder_outputs
batch
[
"encoder_outputs"
]
=
encoder_outputs
with
torch
.
no_grad
():
with
torch
.
no_grad
():
outputs
=
model
(
**
batch
)
outputs
=
eval_
model
(
**
batch
)
# CE (data) loss
# CE (data) loss
ce_loss
=
outputs
.
loss
ce_loss
=
outputs
.
loss
metrics
=
{
"loss"
:
ce_loss
}
metrics
=
{
"loss"
:
ce_loss
}
return
metrics
return
metrics
def
generate_step
(
batch
):
def
generate_step
(
batch
):
model
.
eval
()
batch
.
pop
(
"decoder_attention_mask"
,
None
)
batch
.
pop
(
"decoder_attention_mask"
,
None
)
output_audios
=
accelerator
.
unwrap_model
(
model
,
keep_fp32_wrapper
=
mixed_precision
!=
"fp16"
).
generate
(
**
batch
,
**
gen_kwargs
)
eval_model
=
accelerator
.
unwrap_model
(
model
,
keep_fp32_wrapper
=
mixed_precision
!=
"fp16"
).
eval
()
if
training_args
.
torch_compile
:
eval_model
=
model
.
_orig_mod
output_audios
=
eval_model
.
generate
(
**
batch
,
**
gen_kwargs
)
output_audios
=
accelerator
.
pad_across_processes
(
output_audios
,
dim
=
1
,
pad_index
=
0
)
output_audios
=
accelerator
.
pad_across_processes
(
output_audios
,
dim
=
1
,
pad_index
=
0
)
return
output_audios
return
output_audios
...
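Note on eval_step/generate_step above: torch.compile wraps the model in an OptimizedModule and keeps the original module reachable as `_orig_mod`; the new code routes evaluation and generation through that original module so custom methods such as `generate` run eagerly instead of going through the compiled wrapper. A self-contained sketch of the pattern, using a hypothetical toy module rather than the TTS model:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical stand-in for the TTS model; only illustrates the wrapper."""

    def forward(self, x):
        return x * 2

    @torch.no_grad()
    def generate(self, x):
        # Custom inference-only method, analogous to Hugging Face `generate`.
        return self.forward(x) + 1

model = TinyModel()
compiled = torch.compile(model, backend="inductor", mode="default")

# torch.compile keeps a handle to the original, uncompiled module.
assert compiled._orig_mod is model

# For eval/generate, fall back to the original module so custom methods run
# eagerly instead of through the compiled wrapper.
eval_model = getattr(compiled, "_orig_mod", compiled)
eval_model.eval()
print(eval_model.generate(torch.ones(2)))

Here `getattr(compiled, "_orig_mod", compiled)` plays the role of the `model if not training_args.torch_compile else model._orig_mod` branch in the diff.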
@@ -1593,7 +1591,6 @@ def main():

                 if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
                     train_time += time.time() - train_start
-                    model.eval()
                     # ======================== Evaluating ==============================
                     eval_metrics = []
                     eval_preds = []