chenpangpang / parler-tts · Commits

Commit e51113f9, authored Mar 13, 2024 by Yoach Lacombe
Parent: c6b4674d

fix fp16 training and attention mask in generation

Showing 2 changed files with 36 additions and 17 deletions (+36 -17):
  run_stable_speech_training.py            +21  -6
  stable_speech/modeling_stable_speech.py  +15 -11
run_stable_speech_training.py
@@ -69,7 +69,7 @@ AutoModel.register(DACConfig, DACModel)
 from accelerate import Accelerator
-from accelerate.utils import set_seed
+from accelerate.utils import set_seed, AutocastKwargs
 from accelerate.utils.memory import release_memory
 from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
@@ -452,6 +452,14 @@ class DataTrainingArguments:
             "help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
         }
     )
+    pad_to_multiple_of: Optional[int] = field(
+        default=2,
+        metadata={"help": ("Pad to multiple of for tokenizers.")},
+    )


 @dataclass
 class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
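Aside (not part of the commit): a minimal sketch of what the new pad_to_multiple_of knob does once it is forwarded to a tokenizer's pad() call, as in the collator change below. The t5-small checkpoint is an assumption for illustration only; the training script uses whichever prompt tokenizer is configured.

import torch
from transformers import AutoTokenizer

# Assumed checkpoint, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

features = [
    {"input_ids": tokenizer("a short prompt").input_ids},
    {"input_ids": tokenizer("a slightly longer prompt for the batch").input_ids},
]

# Without pad_to_multiple_of, the batch is padded to the longest sequence; with it,
# the padded length is rounded up to the next multiple of 8, which keeps shapes
# friendly for fp16 / tensor-core kernels.
batch = tokenizer.pad(features, return_tensors="pt", padding="longest", pad_to_multiple_of=8)
print(batch["input_ids"].shape)  # sequence dimension is a multiple of 8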
@@ -546,7 +554,6 @@ class DataCollatorStableSpeechWithPadding:
         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
-        prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding)
+        prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
         # TODO: check it's been padded on the left
         batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
         if "attention_mask" in prompt_input_ids:
             batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
@@ -1214,7 +1221,7 @@ def main():
     # Instantiate custom data collator
     data_collator = DataCollatorStableSpeechWithPadding(
-        audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer
+        audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of
     )

     # Freeze Encoders
@@ -1318,13 +1325,21 @@ def main():
         "temperature": model_args.temperature,
         "max_length": model_args.max_length,
     }
     # TODO: add max_length

     # Define gradient update step fn
     def train_step(
         batch,
+        accelerator,
+        autocast_kwargs,
     ):
         model.train()

+        if mixed_precision == "fp16":
+            # fp16 doesn't work with T5-like models
+            with accelerator.autocast(autocast_handler=autocast_kwargs):
+                encoder_outputs = model.module.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+                batch["encoder_outputs"] = encoder_outputs
+
         outputs = model(**batch)
         # CE (data) loss
         ce_loss = outputs.loss
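Aside (not part of the diff): the hunk above runs the T5-like text encoder inside accelerator.autocast with autocast disabled, since fp16 autocast tends to be unstable for T5-style models. A self-contained sketch of that pattern, using a plain nn.Linear as a stand-in for the text encoder:

import torch
from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

# fp16 mixed precision needs a GPU; fall back to "no" so this sketch also runs on CPU.
accelerator = Accelerator(mixed_precision="fp16" if torch.cuda.is_available() else "no")

encoder = torch.nn.Linear(8, 8)          # stand-in for the T5-like text encoder
encoder = accelerator.prepare(encoder)
x = torch.randn(2, 8, device=accelerator.device)

# AutocastKwargs(enabled=False) turns autocast *off* inside this block, so the forward
# pass runs in the module's own parameter dtype (fp32) even when the surrounding
# training loop uses fp16 mixed precision.
autocast_kwargs = AutocastKwargs(enabled=False)
with accelerator.autocast(autocast_handler=autocast_kwargs):
    out = encoder(x)

print(out.dtype)  # torch.float32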
@@ -1350,7 +1365,7 @@ def main():
         output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
         return output_audios

+    autocast_kwargs = AutocastKwargs(enabled=False)
+
     for epoch in range(epochs_trained, num_epochs):
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
         # TODO: add args
@@ -1374,7 +1389,7 @@ def main():
         for batch in train_dataloader:
             with accelerator.accumulate(model):
-                loss, train_metric = train_step(batch)
+                loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
                 accelerator.backward(loss)

                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
stable_speech/modeling_stable_speech.py
@@ -804,15 +804,18 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         if prompt_hidden_states is not None:
             inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)

-        # TODO: verify if prompt attention mask is required and has to be
-        # As it is, the masked ids from the prompt will still count in the positions embeddings
-        if prompt_attention_mask is not None and attention_mask is not None:
-            attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
-        elif prompt_attention_mask is not None:
-            logger.warning_once("`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour.")
+        # As it is, the masked ids from the prompt will still count in the positions embeddings
+        if prompt_attention_mask is not None and attention_mask is not None:
+            attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
+        elif prompt_attention_mask is not None:
+            logger.warning_once("`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour.")
+            if past_key_values is None:
+                attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
+            else:
+                generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
+                attention_mask = torch.cat([prompt_attention_mask, torch.ones((input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)

         input_shape = inputs_embeds.size()[:-1]

         attention_mask = _prepare_4d_causal_attention_mask(
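Aside (not part of the diff): a standalone sketch, with made-up sizes, of the mask bookkeeping added above. During cached generation only the newest token is fed to the decoder, so a full-length mask has to be rebuilt from the prompt mask plus ones for every position generated so far.

import torch

prompt_attention_mask = torch.tensor([[0, 1, 1, 1]])  # prompt of length 4, left-padded
input_shape = (1, 1)                                   # one new token per step with a KV cache
past_key_values_length = 6                             # prompt (4) + 2 tokens already generated

# Includes the current step, hence the +1.
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1  # = 3
attention_mask = torch.cat(
    [prompt_attention_mask, torch.ones((input_shape[0], generated_length), dtype=prompt_attention_mask.dtype)],
    dim=1,
)
print(attention_mask)  # tensor([[0, 1, 1, 1, 1, 1, 1]]) -> length 7 = past (6) + current (1)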
@@ -1174,7 +1177,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         if prompt_attention_mask is not None:
             prompt_attention_mask = torch.concatenate(
-                prompt_attention_mask, torch.zeros_like(prompt_attention_mask), dim=0
+                [prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
             )

         if past_key_values is not None:
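Aside (not part of the diff): the one-line fix above wraps the tensors in a list, because torch.concatenate (an alias of torch.cat) takes a sequence of tensors as its first argument; passing them as separate positional arguments does not match that signature and the call fails.

import torch

mask = torch.ones(2, 4, dtype=torch.long)

# Buggy form from before the fix: the tensors are passed positionally, which does not
# match cat()'s (tensors, dim) signature and raises a TypeError.
# torch.concatenate(mask, torch.zeros_like(mask), dim=0)

# Fixed form: a list (or tuple) of tensors, concatenated along dim 0.
doubled = torch.concatenate([mask, torch.zeros_like(mask)], dim=0)
print(doubled.shape)  # torch.Size([4, 4])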
@@ -2061,8 +2064,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if decoder_attention_mask is not None:
             decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
         if prompt_hidden_states is not None:
-            # TODO: ? we probably don't want to keep guidance scale here ? different task than musicgeneration
-            prompt_hidden_states = torch.concatenate([prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0)
+            prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))
+        if prompt_attention_mask is not None:
+            prompt_attention_mask = prompt_attention_mask.repeat((2, 1))

         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
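Aside (not part of the diff): a small sketch, with made-up shapes, of the classifier-free-guidance batching in the hunk above. The prompt conditioning tensors are duplicated along the batch dimension so the conditional and unconditional branches go through the decoder in a single forward pass.

import torch

prompt_hidden_states = torch.randn(1, 5, 16)                # (batch, prompt_len, hidden)
prompt_attention_mask = torch.ones(1, 5, dtype=torch.long)  # (batch, prompt_len)

# Duplicate along the batch dimension for the two guidance branches.
prompt_hidden_states = prompt_hidden_states.repeat((2, 1, 1))  # -> (2, 5, 16)
prompt_attention_mask = prompt_attention_mask.repeat((2, 1))   # -> (2, 5)

print(prompt_hidden_states.shape, prompt_attention_mask.shape)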