Commit e51113f9 authored by Yoach Lacombe

fix fp16 training and attention mask in generation

parent c6b4674d
@@ -69,7 +69,7 @@ AutoModel.register(DACConfig, DACModel)
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import set_seed, AutocastKwargs
from accelerate.utils.memory import release_memory
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
@@ -452,6 +452,14 @@ class DataTrainingArguments:
"help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
}
)
pad_to_multiple_of: Optional[int] = field(
default=2,
metadata={
"help": (
"Pad to multiple of for tokenizers."
)
},
)
@dataclass
class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
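
Note: the new `pad_to_multiple_of` argument is forwarded to the tokenizers' `pad` calls in the data collator hunk below. A minimal sketch of its effect; the "t5-small" checkpoint and the token ids are only illustrative placeholders:

from transformers import AutoTokenizer

# Any tokenizer with a pad token works; "t5-small" is just an example checkpoint.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
features = [{"input_ids": [100, 101, 102]}, {"input_ids": [100]}]

# The longest sequence has length 3; pad_to_multiple_of=2 rounds the padded
# length up to 4, which keeps tensor shapes friendly for fp16 kernels.
batch = tokenizer.pad(features, padding="longest", pad_to_multiple_of=2, return_tensors="pt")
print(batch["input_ids"].shape)  # torch.Size([2, 4])
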
@@ -546,7 +554,6 @@ class DataCollatorStableSpeechWithPadding:
prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
# TODO: check it's been padded on the left
batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
if "attention_mask" in prompt_input_ids:
batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
@@ -1214,7 +1221,7 @@ def main():
# Instantiate custom data collator
data_collator = DataCollatorStableSpeechWithPadding(
audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer
audio_feature_extractor=feature_extractor, feature_extractor_input_name=feature_extractor_input_name, prompt_tokenizer=prompt_tokenizer, description_tokenizer=description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of
)
# Freeze Encoders
@@ -1318,13 +1325,21 @@ def main():
"temperature": model_args.temperature,
"max_length": model_args.max_length,
}
# TODO: add max_length
# Define gradient update step fn
def train_step(
batch,
accelerator,
autocast_kwargs,
):
model.train()
if mixed_precision == "fp16":
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
encoder_outputs = model.module.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
batch["encoder_outputs"] = encoder_outputs
outputs = model(**batch)
# CE (data) loss
ce_loss = outputs.loss
@@ -1350,7 +1365,7 @@ def main():
output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
return output_audios
autocast_kwargs = AutocastKwargs(enabled=False)
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
# TODO: add args
@@ -1374,7 +1389,7 @@ def main():
for batch in train_dataloader:
with accelerator.accumulate(model):
loss, train_metric = train_step(batch)
loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
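
The core of the fp16 fix in `train_step` above is that the T5-like text encoder is run with autocast disabled (via `AutocastKwargs(enabled=False)`), since T5 activations are known to overflow in half precision; its outputs are then placed back into the batch so the rest of the forward pass can stay in fp16. A slightly more self-contained sketch of the pattern, assuming `model` is the distributed wrapper returned by `accelerator.prepare` (so the underlying model and its `text_encoder` live on `model.module`):

from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

accelerator = Accelerator(mixed_precision="fp16")
# Handler used to *disable* autocast around the text encoder only.
autocast_kwargs = AutocastKwargs(enabled=False)

def train_step(batch, model, accelerator, autocast_kwargs):
    model.train()
    if accelerator.mixed_precision == "fp16":
        # Run the T5-like text encoder in full precision.
        with accelerator.autocast(autocast_handler=autocast_kwargs):
            encoder_outputs = model.module.text_encoder(
                input_ids=batch.get("input_ids"),
                attention_mask=batch.get("attention_mask", None),
            )
        # Hand the precomputed encoder outputs to the main forward pass so the
        # text encoder is not run again under fp16 autocast.
        batch["encoder_outputs"] = encoder_outputs
    outputs = model(**batch)
    return outputs.loss
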
@@ -804,15 +804,18 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
if prompt_hidden_states is not None:
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
# TODO: verify if prompt attention mask is required and has to be
# As it is, the masked ids from the prompt will still count in the positions embeddings
if prompt_attention_mask is not None and attention_mask is not None:
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
elif prompt_attention_mask is not None:
logger.warning_once(
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
)
# As it is, the masked ids from the prompt will still count in the positions embeddings
if prompt_attention_mask is not None and attention_mask is not None:
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
elif prompt_attention_mask is not None:
logger.warning_once(
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
)
if past_key_values is None:
attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
else:
generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
attention_mask = torch.cat([prompt_attention_mask, torch.ones((input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
input_shape = inputs_embeds.size()[:-1]
attention_mask = _prepare_4d_causal_attention_mask(
@@ -1174,7 +1177,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
if prompt_attention_mask is not None:
prompt_attention_mask = torch.concatenate(
prompt_attention_mask, torch.zeros_like(prompt_attention_mask), dim=0
[prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
)
if past_key_values is not None:
@@ -2061,8 +2064,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
# TODO: we probably don't want to keep guidance scale here? This is a different task than music generation
prompt_hidden_states = torch.concatenate([prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0)
prompt_hidden_states = prompt_hidden_states.repeat((2,1,1))
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.repeat((2,1))
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
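
The generation-side fix above builds a full decoder attention mask from `prompt_attention_mask` when no `attention_mask` is supplied: without a KV cache it covers the prompt plus every current input position, and with a cache it covers the prompt plus every token generated so far, including the current one. A standalone shape-level sketch of that arithmetic (the lengths below are made up for illustration):

import torch

# Prompt of length 5 with 2 padded positions masked out.
prompt_attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
# Pretend 3 tokens have already been generated, so the cache holds
# 5 (prompt) + 3 (generated) = 8 positions and the decoder now sees 1 new token.
past_key_values_length = 8
input_shape = (1, 1)  # (batch, new tokens) when the KV cache is used

use_cache = past_key_values_length > 0  # mirrors the `past_key_values is None` check above
if not use_cache:
    # First pass: prompt mask + ones over every current input position.
    attention_mask = torch.cat(
        [prompt_attention_mask, torch.ones(input_shape, dtype=prompt_attention_mask.dtype)], dim=1
    )
else:
    # Cached generation: tokens generated so far plus the current one.
    generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
    attention_mask = torch.cat(
        [prompt_attention_mask, torch.ones((input_shape[0], generated_length), dtype=prompt_attention_mask.dtype)],
        dim=1,
    )

print(attention_mask)        # tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1]])
print(attention_mask.shape)  # torch.Size([1, 9]) = prompt (5) + generated (3) + current (1)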