Commit 5b5167d8 authored by sanchit-gandhi

run sweep

parent b1fb7844
@@ -7,10 +7,11 @@ command:
  - --do_eval
  - --trust_remote_code
  - --overwrite_output_dir
  - --ignore_mismatched_sizes
  - ${args}
method: grid
metric:
  goal: minimize
  goal: maximize
  name: eval/accuracy
parameters:
  model_name_or_path:
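This hunk flips the sweep objective from minimize to maximize, the right direction for an accuracy metric, and threads --ignore_mismatched_sizes through to the training command. As a sketch (the file name and project below are placeholders, not part of this commit), a config like this is typically registered with the wandb API:

```python
# Sketch: register the grid sweep defined by the YAML above and print its ID.
# "sweep.yaml" and the project name are assumptions for illustration.
import yaml
import wandb

with open("sweep.yaml") as f:
    sweep_config = yaml.safe_load(f)

sweep_id = wandb.sweep(sweep=sweep_config, project="audio-classification")

# Workers are then started from a shell, one grid point per run:
#   wandb agent <entity>/audio-classification/<sweep_id>
print(sweep_id)
```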
@@ -40,14 +41,18 @@ parameters:
    value: false
  learning_rate:
    value: 1e-4
  lr_scheduler_type:
    value: constant_with_warmup
  max_length_seconds:
    value: 20
  min_length_seconds:
    value: 5
  attention_mask:
    value: false
  warmup_ratio:
    value: 0.1
  num_train_epochs:
    value: 5
  warmup_steps:
    value: 50
  max_steps:
    value: 1000
  per_device_train_batch_size:
    value: 32
  per_device_eval_batch_size:
@@ -55,19 +60,21 @@ parameters:
  preprocessing_num_workers:
    value: 16
  dataloader_num_workers:
    value: 4
    value: 8
  logging_strategy:
    value: steps
  logging_steps:
    value: 10
  evaluation_strategy:
    value: epoch
    value: steps
  eval_steps:
    value: 1000
  save_strategy:
    value: epoch
    value: steps
  save_steps:
    value: 1000
  metric_for_best_model:
    value: accuracy
  save_total_limit:
    value: 3
  freeze_base_model:
    values:
      - true
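Most parameters above are pinned to a single value; freeze_base_model takes a values list and is therefore part of the grid. As a rough sketch (output_dir is a placeholder, not from this commit), here is how the pinned values map onto transformers' TrainingArguments for anyone reproducing one grid point without W&B:

```python
# Approximate TrainingArguments for one grid point of the sweep above.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./sweep-run",            # placeholder
    learning_rate=1e-4,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=50,
    max_steps=1000,                      # step-based training replaces num_train_epochs
    per_device_train_batch_size=32,
    dataloader_num_workers=8,
    logging_strategy="steps",
    logging_steps=10,
    evaluation_strategy="steps",         # evaluate on a step schedule instead of per epoch
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=3,
)
```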
@@ -35,7 +35,7 @@ from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed, WhisperForAudioClassification,
    set_seed,
)
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from transformers.trainer_utils import get_last_checkpoint
@@ -165,7 +165,11 @@ class DataTrainingArguments:
    )
    max_length_seconds: float = field(
        default=20,
        metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
        metadata={"help": "Audio samples will be randomly cut to this length during training if the value is set."},
    )
    min_length_seconds: float = field(
        default=5,
        metadata={"help": "Audio samples shorter than this value will be filtered out during training if the value is set."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
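The new min_length_seconds field mirrors the existing max_length_seconds one. A self-contained sketch (dataclass trimmed to these two fields, not the script's full definition) of how HfArgumentParser surfaces it on the command line:

```python
# Minimal sketch: dataclass fields like the ones above become CLI flags.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    max_length_seconds: float = field(default=20)
    min_length_seconds: float = field(default=5)

parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--min_length_seconds", "5"]
)
print(data_args.min_length_seconds)  # 5.0
```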
@@ -197,7 +201,10 @@ class ModelArguments:
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )
    freeze_feature_encoder: bool = field(
        default=False, metadata={"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."}
        default=False,
        metadata={
            "help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."
        },
    )
    freeze_base_model: bool = field(
        default=True, metadata={"help": "Whether to freeze the base encoder of the model."}
@@ -225,7 +232,7 @@ class ModelArguments:
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        default=True,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )
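Defaulting ignore_mismatched_sizes to True matches the new --ignore_mismatched_sizes flag in the sweep command. A sketch (checkpoint name and label count are examples, not taken from this commit) of what the flag permits:

```python
# Sketch: load a checkpoint whose classification head has a different shape
# than the target task needs; mismatched head weights are reinitialized
# instead of raising a size-mismatch error.
from transformers import AutoModelForAudioClassification

model = AutoModelForAudioClassification.from_pretrained(
    "openai/whisper-small",        # example checkpoint
    num_labels=100,                # example label count for the new task
    ignore_mismatched_sizes=True,
)
```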
@@ -535,6 +542,19 @@ def main():
        desc="Filtering by labels",
    )

    # filter training data with inputs < min_input_length
    min_input_length = data_args.min_length_seconds * sampling_rate

    def is_audio_valid(audio):
        return len(audio["array"]) > min_input_length

    raw_datasets = raw_datasets.filter(
        is_audio_valid,
        input_columns=["audio"],
        num_proc=data_args.preprocessing_num_workers,
        desc="Filtering by audio length",
    )

    # Prepare label mappings
    raw_datasets = raw_datasets.map(
        lambda label: {"labels": preprocess_labels(label)},
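A toy check (not part of the commit) of the new length filter: with min_length_seconds=5 at a 16 kHz sampling rate, any clip of 80,000 samples or fewer is dropped.

```python
import numpy as np
from datasets import Dataset

sampling_rate = 16_000
min_input_length = 5 * sampling_rate  # min_length_seconds * sampling_rate

ds = Dataset.from_dict(
    {"audio": [{"array": np.zeros(3 * sampling_rate)},   # 3 s -> filtered out
               {"array": np.zeros(8 * sampling_rate)}]}  # 8 s -> kept
)
ds = ds.filter(lambda audio: len(audio["array"]) > min_input_length,
               input_columns=["audio"])
print(len(ds))  # 1
```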
@@ -631,7 +651,8 @@ def main():
    if hasattr(model, "freeze_base_model"):
        # wav2vec2-style models
        model.freeze_base_model()
        model.freeze_feature_encoder()
        if hasattr(model, "freeze_feature_encoder"):
            model.freeze_feature_encoder()
    elif hasattr(model, "freeze_encoder"):
        # whisper-style models
        model.freeze_encoder()
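The added hasattr guard makes the freezing block safe for architectures that expose freeze_base_model but no freeze_feature_encoder. A generic, self-contained sanity check (illustrative only, using a stand-in module rather than the script's model) for verifying what a freezing call leaves trainable:

```python
import torch.nn as nn

def summarize_trainable(model: nn.Module) -> None:
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable: {trainable:,} / {total:,}")

# Stand-in module: freeze the "base" and keep the head trainable.
toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 2))
for p in toy[0].parameters():
    p.requires_grad = False

summarize_trainable(toy)  # trainable: 34 / 306
```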