Commit 7ae1e8e3 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

remove trainer code

parent 9d25447e
...@@ -1455,141 +1455,6 @@ def main(): ...@@ -1455,141 +1455,6 @@ def main():
accelerator.end_training() accelerator.end_training()
###########################################################################
# Initialize StableSpeechTrainer — the HF-Trainer subclass that drives
# training/evaluation of the text-to-speech model.
# NOTE(review): original indentation was lost in this capture; these
# statements live inside main() in the real file.
trainer = StableSpeechTrainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
# Datasets are attached only for the phases actually requested via CLI flags.
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
# NOTE(review): the audio feature extractor is passed where a tokenizer is
# expected — presumably so it gets saved alongside checkpoints; confirm.
tokenizer=feature_extractor,
)
# Only wire up audio-sample logging when W&B reporting and evaluation
# are both enabled and the user asked for it.
if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to and training_args.do_eval:
# Number of eval rows to log; defaults to the whole eval split when
# --max_eval_samples was not provided.
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
)
def decode_predictions(predictions):
    """Unpack a trainer ``PredictionOutput`` into a dict of generated audio.

    Returns a single-key dict ``{"audio": np.ndarray}`` built from the
    ``predictions`` attribute of the prediction output.
    """
    return {"audio": np.array(predictions.predictions)}
class WandbPredictionProgressCallback(WandbCallback):
    """W&B callback that logs generated audio samples after every evaluation.

    Args:
        trainer (Seq2SeqTrainer): The Hugging Face trainer used to run prediction.
        val_dataset (Dataset): Validation dataset to draw sample rows from.
        description_tokenizer: Tokenizer used to decode the conditioning prompts.
        num_samples (int, optional): Number of validation rows to generate and
            log. Defaults to 8.
    """

    def __init__(self, trainer, val_dataset, description_tokenizer,
                 num_samples=8):
        """Store the trainer/tokenizer and pre-select the rows to log."""
        super().__init__()
        self.trainer = trainer
        self.description_tokenizer = description_tokenizer
        # Keep only the first `num_samples` rows so logging stays lightweight.
        self.sample_dataset = val_dataset.select(range(num_samples))

    def on_evaluate(self, args, state, control, **kwargs):
        """After each evaluation, generate audio for the sample rows and log it."""
        super().on_evaluate(args, state, control, **kwargs)
        # Run generation on the held-out sample rows and unpack the audio.
        raw_output = self.trainer.predict(self.sample_dataset)
        generated = decode_predictions(raw_output)["audio"]
        # Recover the text prompts that conditioned each generated clip.
        prompts = self.description_tokenizer.batch_decode(
            self.sample_dataset["input_ids"], skip_special_tokens=True
        )
        # NOTE(review): `sampling_rate` is taken from the enclosing scope —
        # confirm it is defined before this callback fires.
        clips = [
            self._wandb.Audio(clip, caption=prompt, sample_rate=sampling_rate)
            for clip, prompt in zip(generated, prompts)
        ]
        self._wandb.log({"sample_songs": clips})
# Instantiate the WandbPredictionProgressCallback
progress_callback = WandbPredictionProgressCallback(
trainer=trainer,
val_dataset=vectorized_datasets["eval"],
description_tokenizer=description_tokenizer,
# NOTE(review): num_samples is max_eval_samples, which defaults to the FULL
# eval split — this logs audio for every eval example; confirm intended.
num_samples=max_eval_samples,
)
# Add the callback to the trainer
trainer.add_callback(progress_callback)
# 8. Finally, we can start training
# Training
if training_args.do_train:
# use last checkpoint if exist
# Resume from the auto-detected checkpoint in the output dir when present.
if last_checkpoint is not None:
checkpoint = last_checkpoint
# TODO: it's loading trainer from model_name_or_path doesn't work if saving config
# elif os.path.isdir(model_args.model_name_or_path):
# checkpoint = model_args.model_name_or_path
else:
checkpoint = None
# past_key_values/attentions are large and unused by metrics, so drop them.
train_result = trainer.train(resume_from_checkpoint=checkpoint, ignore_keys_for_eval=["past_key_values", "attentions"])
trainer.save_model()
metrics = train_result.metrics
# Record the actual number of training samples used (respects --max_train_samples).
max_train_samples = (
data_args.max_train_samples
if data_args.max_train_samples is not None
else len(vectorized_datasets["train"])
)
metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
# Persist trainer state (global step, RNG, etc.) so training can be resumed.
trainer.save_state()
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
# Record the actual number of eval samples used (respects --max_eval_samples).
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
)
metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Write model card and (optionally) push to hub
# Fall back to "na" when no dataset config name was provided.
config_name = data_args.train_dataset_config_name if data_args.train_dataset_config_name is not None else "na"
# Metadata baked into the generated model card / hub repo.
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"tasks": "text-to-speech",
"tags": ["text-to-speech", data_args.train_dataset_name],
"dataset_args": (
f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
f" {data_args.eval_split_name}"
),
"dataset": f"{data_args.train_dataset_name.upper()} - {config_name.upper()}",
}
# Either upload the model + card to the Hub or just write the card locally.
if training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(**kwargs)
return results
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment