Commit daca5721 authored by yoach@huggingface.co

improve modeling code and logging

parent b7f5febc
@@ -258,7 +258,7 @@ class ModelArguments:
         metadata={"help": "Whether to do sampling or greedy decoding."},
     )
     max_length: int = field(
-        default=400,  # TODO
+        default=1500,  # TODO
         metadata={"help": "Maximum length used at generation time."},
     )
     bandwidth: float = field(
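Fields like these are typically wired to command-line flags via `transformers.HfArgumentParser`; that parsing mechanism is an assumption here, and the `ModelArguments` below is a hypothetical version pared down to the one field this hunk touches.

```python
# Minimal sketch: dataclass fields like `max_length` become CLI flags through
# HfArgumentParser (assumed); run with --max_length 2000 to override the default.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    max_length: int = field(
        default=1500,
        metadata={"help": "Maximum length used at generation time."},
    )


(model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses()
print(model_args.max_length)  # 1500 unless overridden on the command line
```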
@@ -741,7 +741,22 @@ def main():
         project_dir=training_args.output_dir,
     )
-    accelerator.init_trackers(project_name=data_args.wandb_project)
+    accelerator.init_trackers(project_name=data_args.wandb_project, config={
+        "learning_rate": training_args.learning_rate,
+        "model_name_or_path": model_args.model_name_or_path,
+        "num_train_epochs": training_args.num_train_epochs,
+        "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
+        "per_device_train_batch_size": training_args.per_device_train_batch_size,
+        "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
+        "mixed_precision": mixed_precision,
+        "lr_scheduler_type": training_args.lr_scheduler_type,
+        "warmup_steps": training_args.warmup_steps,
+        "freeze_text_encoder": model_args.freeze_text_encoder,
+        "max_duration_in_seconds": data_args.max_duration_in_seconds,
+        "weight_decay": training_args.weight_decay,
+        "adam_beta1": training_args.adam_beta1,
+        "adam_beta2": training_args.adam_beta2,
+    })

     # Detect the last checkpoint and, if one exists, resume from it
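The `config` dict added above is forwarded by Accelerate to every tracker it initializes, so these hyperparameters show up in the run's config in the wandb UI. A minimal sketch of the API, with a hypothetical project name and a reduced config:

```python
# Sketch: Accelerate passes `config` through to each tracker (here wandb),
# making the hyperparameters searchable alongside the logged metrics.
from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(
    project_name="stable-speech",  # hypothetical project name
    config={"learning_rate": 1e-4, "warmup_steps": 500},  # reduced config
)
accelerator.log({"train_loss": 0.42}, step=1)
accelerator.end_training()
```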
@@ -1073,7 +1088,7 @@ def main():
         input_columns=["input_ids", "prompt_input_ids"],
         desc="Postprocessing labeling",
         with_indices=True,
-        writer_batch_size=200,
+        writer_batch_size=100,
     )
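`writer_batch_size` controls how many processed rows `datasets.Dataset.map` buffers in memory before flushing them to the on-disk Arrow cache, so halving it trades some write throughput for a lower peak-memory footprint, which matters when each row carries long token sequences. A toy sketch:

```python
# Toy sketch: writer_batch_size bounds how many mapped rows sit in memory
# before being flushed to the Arrow cache file on disk.
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]] * 1_000})
ds = ds.map(
    lambda batch: batch,    # identity stand-in for the real postprocessing
    batched=True,
    writer_batch_size=100,  # flush every 100 rows -> lower peak memory
    desc="Postprocessing labeling",
)
```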
@@ -1181,6 +1196,7 @@ def main():
         lr=training_args.learning_rate,
         betas=(training_args.adam_beta1, training_args.adam_beta2),
         eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
     )

     # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
...
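Two details in the hunk above, assuming the optimizer is `torch.optim.AdamW` (consistent with the keyword arguments shown): AdamW's own `weight_decay` default is `1e-2`, so before this change the value in `training_args` was silently ignored; and because each process steps the LR scheduler once per optimizer step under Accelerate, warmup and total steps are scaled by the process count. A sketch of both, with hypothetical numbers:

```python
# Sketch, assuming torch.optim.AdamW and transformers.get_scheduler:
# without the explicit keyword, AdamW used its built-in default
# weight_decay of 1e-2 instead of the configured training argument.
import torch
from transformers import get_scheduler

model = torch.nn.Linear(8, 8)  # stand-in for the real model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,  # now taken from training_args.weight_decay
)

# Under Accelerate the scheduler is stepped once per process per optimizer
# step, hence the warmup / total steps are multiplied by the process count.
num_processes = 2  # hypothetical; normally accelerator.num_processes
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=500 * num_processes,
    num_training_steps=10_000 * num_processes,
)
```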
@@ -22,6 +22,11 @@ _deps = [
     "transformers>=4.34.0",
     "datasets[audio]>=2.14.5",
     "torch",
+    "accelerate",
+    "evaluate",
+    "sentencepiece",
+    "descript-audio-codec",
+    "jiwer",
 ]
 _extras_dev_deps = [
...
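The new entries round out the training and evaluation stack: `accelerate` drives the distributed loop, `evaluate` plus `jiwer` back the WER metric, `sentencepiece` supports T5-style tokenizers, and `descript-audio-codec` provides the DAC codec. A `_deps` list of this shape usually feeds `install_requires`; a hedged sketch under that assumption, with a hypothetical package name:

```python
# Sketch of the common pattern where a module-level `_deps` list
# feeds setuptools; the surrounding setup.py is assumed, not shown.
from setuptools import find_packages, setup

_deps = [
    "transformers>=4.34.0",
    "datasets[audio]>=2.14.5",
    "torch",
    "accelerate",            # multi-GPU / mixed-precision training loop
    "evaluate",              # metric loading (e.g. WER)
    "sentencepiece",         # tokenizer backend for T5-style text encoders
    "descript-audio-codec",  # DAC audio codec
    "jiwer",                 # backend used by the WER metric
]

setup(
    name="stable-speech",  # hypothetical package name
    packages=find_packages(),
    install_requires=_deps,
)
```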
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
-from transformers import AutoConfig, AutoModel
+from transformers import AutoConfig, AutoModel, AutoModelForTextEncoding
 from transformers.activations import ACT2FN
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
@@ -1792,7 +1792,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             kwargs_text_encoder["config"] = encoder_config

-        text_encoder = AutoModel.from_pretrained(
+        text_encoder = AutoModelForTextEncoding.from_pretrained(
             text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
         )
...
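`AutoModelForTextEncoding` resolves encoder-decoder checkpoints to their encoder-only class (T5 maps to `T5EncoderModel`), so the text encoder no longer instantiates an unused decoder the way `AutoModel` would. A quick sketch of the difference:

```python
# Sketch: unlike AutoModel, AutoModelForTextEncoding loads only the
# encoder half of an encoder-decoder checkpoint.
from transformers import AutoModel, AutoModelForTextEncoding

full = AutoModel.from_pretrained("t5-small")
encoder_only = AutoModelForTextEncoding.from_pretrained("t5-small")

print(type(full).__name__)          # -> T5Model (encoder + decoder)
print(type(encoder_only).__name__)  # -> T5EncoderModel (encoder only)
```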