Commit daca5721 authored by yoach@huggingface.co

improve modeling code and logging

parent b7f5febc
@@ -258,7 +258,7 @@ class ModelArguments:
         metadata={"help": "Whether to do sampling or greedy decoding."},
     )
     max_length: int = field(
-        default=400,  # TODO
+        default=1500,  # TODO
         metadata={"help": "Whether to do sampling or greedy decoding."},
     )
     bandwidth: float = field(
@@ -741,7 +741,22 @@ def main():
         project_dir=training_args.output_dir,
     )
-    accelerator.init_trackers(project_name=data_args.wandb_project)
+    accelerator.init_trackers(project_name=data_args.wandb_project, config={
+        "learning_rate": training_args.learning_rate,
+        "model_name_or_path": model_args.model_name_or_path,
+        "num_train_epochs": training_args.num_train_epochs,
+        "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
+        "per_device_train_batch_size": training_args.per_device_train_batch_size,
+        "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
+        "mixed_precision": mixed_precision,
+        "lr_scheduler_type": training_args.lr_scheduler_type,
+        "warmup_steps": training_args.warmup_steps,
+        "freeze_text_encoder": model_args.freeze_text_encoder,
+        "max_duration_in_seconds": data_args.max_duration_in_seconds,
+        "weight_decay": training_args.weight_decay,
+        "adam_beta1": training_args.adam_beta1,
+        "adam_beta2": training_args.adam_beta2,
+    })

     # Detecting last checkpoint and eventually continue from last checkpoint
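For context, `Accelerator.init_trackers` forwards the `config` dict to whichever tracker was configured via `log_with` (presumably wandb here, given `data_args.wandb_project`), so each run records its hyperparameters next to its metrics. A minimal sketch with placeholder values, not the script's actual settings:

```python
from accelerate import Accelerator

# log_with selects the tracker backend; project_dir and the values below
# are placeholders for illustration only.
accelerator = Accelerator(log_with="wandb", project_dir="./output")

# The config dict is passed through to the tracker's init
# (e.g. wandb.init(config=...)), so hyperparameters are stored with the run.
accelerator.init_trackers(
    project_name="my-tts-project",
    config={"learning_rate": 1e-4, "per_device_train_batch_size": 8},
)

# Metrics logged during training then appear under the same run.
accelerator.log({"train_loss": 0.123}, step=1)
accelerator.end_training()
```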
@@ -1073,7 +1088,7 @@ def main():
         input_columns=["input_ids", "prompt_input_ids"],
         desc="Postprocessing labeling",
         with_indices=True,
-        writer_batch_size=200,
+        writer_batch_size=100,
     )
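`writer_batch_size` is the number of processed rows `datasets` buffers before flushing them to the Arrow cache file; lowering it from 200 to 100 trades a little write throughput for a smaller peak memory footprint, which matters when each row holds long audio token sequences. A standalone sketch with toy data, not the script's actual columns:

```python
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]] * 1000})

# Fewer rows are kept in memory between flushes to the on-disk Arrow cache,
# at the cost of more frequent writes.
ds = ds.map(
    lambda batch: batch,
    batched=True,
    writer_batch_size=100,
    desc="Postprocessing labeling",
)
```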
@@ -1181,6 +1196,7 @@ def main():
         lr=training_args.learning_rate,
         betas=(training_args.adam_beta1, training_args.adam_beta2),
         eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
     )

     # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
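The hunk only shows the keyword arguments, so the optimizer class is an assumption here; if it is `torch.optim.AdamW`, passing `weight_decay` enables decoupled weight decay applied at each optimizer step:

```python
import torch

model = torch.nn.Linear(16, 16)  # stand-in for the real model

# weight_decay here is a placeholder value; the script would read it from
# training_args.weight_decay, as wired up in the hunk above.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)
```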
...
@@ -22,6 +22,11 @@ _deps = [
     "transformers>=4.34.0",
     "datasets[audio]>=2.14.5",
     "torch",
+    "accelerate",
+    "evaluate",
+    "sentencepiece",
+    "descript-audio-codec",
+    "jiwer",
 ]

 _extras_dev_deps = [
...
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
-from transformers import AutoConfig, AutoModel
+from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding
 from transformers.activations import ACT2FN
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
@@ -1792,7 +1792,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 kwargs_text_encoder["config"] = encoder_config

-            text_encoder = AutoModel.from_pretrained(
+            text_encoder = AutoModelForTextEncoding.from_pretrained(
                 text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
             )
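The usual reason for this switch: `AutoModelForTextEncoding` resolves encoder-decoder text checkpoints (e.g. T5) to their encoder-only class, whereas `AutoModel` would load the full seq2seq model, including decoder weights the conditioning branch never uses. A small sketch with an illustrative checkpoint, not necessarily the one this project uses:

```python
from transformers import AutoModel, AutoModelForTextEncoding

checkpoint = "t5-small"  # illustrative only

# AutoModel returns the full encoder-decoder model for T5-style checkpoints.
full_model = AutoModel.from_pretrained(checkpoint)

# AutoModelForTextEncoding maps the same checkpoint to its encoder-only class,
# which is all a text-conditioning branch needs.
text_encoder = AutoModelForTextEncoding.from_pretrained(checkpoint)

print(type(full_model).__name__)    # T5Model
print(type(text_encoder).__name__)  # T5EncoderModel
```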