Commit 5b5167d8 authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

run sweep

parent b1fb7844
...@@ -7,10 +7,11 @@ command: ...@@ -7,10 +7,11 @@ command:
- --do_eval - --do_eval
- --trust_remote_code - --trust_remote_code
- --overwrite_output_dir - --overwrite_output_dir
- --ignore_mismatched_sizes
- ${args} - ${args}
method: grid method: grid
metric: metric:
goal: minimize goal: maximize
name: eval/accuracy name: eval/accuracy
parameters: parameters:
model_name_or_path: model_name_or_path:
...@@ -40,14 +41,18 @@ parameters: ...@@ -40,14 +41,18 @@ parameters:
value: false value: false
learning_rate: learning_rate:
value: 1e-4 value: 1e-4
lr_scheduler_type:
value: constant_with_warmup
max_length_seconds: max_length_seconds:
value: 20 value: 20
min_length_seconds:
value: 5
attention_mask: attention_mask:
value: false value: false
warmup_ratio: warmup_steps:
value: 0.1 value: 50
num_train_epochs: max_steps:
value: 5 value: 1000
per_device_train_batch_size: per_device_train_batch_size:
value: 32 value: 32
per_device_eval_batch_size: per_device_eval_batch_size:
...@@ -55,19 +60,21 @@ parameters: ...@@ -55,19 +60,21 @@ parameters:
preprocessing_num_workers: preprocessing_num_workers:
value: 16 value: 16
dataloader_num_workers: dataloader_num_workers:
value: 4 value: 8
logging_strategy: logging_strategy:
value: steps value: steps
logging_steps: logging_steps:
value: 10 value: 10
evaluation_strategy: evaluation_strategy:
value: epoch value: steps
eval_steps:
value: 1000
save_strategy: save_strategy:
value: epoch value: steps
save_steps:
value: 1000
metric_for_best_model: metric_for_best_model:
value: accuracy value: accuracy
save_total_limit:
value: 3
freeze_base_model: freeze_base_model:
values: values:
- true - true
......
...@@ -35,7 +35,7 @@ from transformers import ( ...@@ -35,7 +35,7 @@ from transformers import (
HfArgumentParser, HfArgumentParser,
Trainer, Trainer,
TrainingArguments, TrainingArguments,
set_seed, WhisperForAudioClassification, set_seed,
) )
from transformers.models.whisper.tokenization_whisper import LANGUAGES from transformers.models.whisper.tokenization_whisper import LANGUAGES
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
...@@ -165,7 +165,11 @@ class DataTrainingArguments: ...@@ -165,7 +165,11 @@ class DataTrainingArguments:
) )
max_length_seconds: float = field( max_length_seconds: float = field(
default=20, default=20,
metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."}, metadata={"help": "Audio samples will be randomly cut to this length during training if the value is set."},
)
min_length_seconds: float = field(
default=5,
metadata={"help": "Audio samples less than this value will be filtered during training if the value is set."},
) )
preprocessing_num_workers: Optional[int] = field( preprocessing_num_workers: Optional[int] = field(
default=None, default=None,
...@@ -197,7 +201,10 @@ class ModelArguments: ...@@ -197,7 +201,10 @@ class ModelArguments:
default=None, metadata={"help": "Name or path of preprocessor config."} default=None, metadata={"help": "Name or path of preprocessor config."}
) )
freeze_feature_encoder: bool = field( freeze_feature_encoder: bool = field(
default=False, metadata={"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."} default=False,
metadata={
"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."
},
) )
freeze_base_model: bool = field( freeze_base_model: bool = field(
default=True, metadata={"help": "Whether to freeze the base encoder of the model."} default=True, metadata={"help": "Whether to freeze the base encoder of the model."}
...@@ -225,7 +232,7 @@ class ModelArguments: ...@@ -225,7 +232,7 @@ class ModelArguments:
}, },
) )
ignore_mismatched_sizes: bool = field( ignore_mismatched_sizes: bool = field(
default=False, default=True,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
) )
...@@ -535,6 +542,19 @@ def main(): ...@@ -535,6 +542,19 @@ def main():
desc="Filtering by labels", desc="Filtering by labels",
) )
# filter training data with inputs < min_input_length
min_input_length = data_args.min_length_seconds * sampling_rate
def is_audio_valid(audio):
return len(audio["array"]) > min_input_length
raw_datasets = raw_datasets.filter(
is_audio_valid,
input_columns=["audio"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by audio length",
)
# Prepare label mappings # Prepare label mappings
raw_datasets = raw_datasets.map( raw_datasets = raw_datasets.map(
lambda label: {"labels": preprocess_labels(label)}, lambda label: {"labels": preprocess_labels(label)},
...@@ -631,7 +651,8 @@ def main(): ...@@ -631,7 +651,8 @@ def main():
if hasattr(model, "freeze_base_model"): if hasattr(model, "freeze_base_model"):
# wav2vec2-style models # wav2vec2-style models
model.freeze_base_model() model.freeze_base_model()
model.freeze_feature_encoder() if hasattr(model, "freeze_feature_encoder"):
model.freeze_feature_encoder()
elif hasattr(model, "freeze_encoder"): elif hasattr(model, "freeze_encoder"):
# whisper-style models # whisper-style models
model.freeze_encoder() model.freeze_encoder()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment