Commit b1fb7844 authored by sanchit-gandhi
Browse files

sweep over models/freeze

parent 83953064
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
python run_audio_classification.py \ python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \ --model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \ --train_dataset_name "sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
--train_dataset_config_name "main+en_accented+default" \ --train_dataset_config_name "default+en_accented+default" \
--train_split_name "train+test+validation" \ --train_split_name "train+test+validation" \
--train_label_column_name "accent+accent+accent" \ --train_label_column_name "accent+accent+accent" \
--eval_dataset_name "sanchit-gandhi/edacc" \ --eval_dataset_name "sanchit-gandhi/edacc" \
......
# Sweep configuration (W&B sweep format) for run_audio_classification.py.
# Grid search over three base checkpoints x two freeze_base_model settings
# (6 runs); every other parameter is held at a single fixed value.
command:
  - python3
  - ${program}
  - --load_best_model_at_end
  - --fp16
  - --do_train
  - --do_eval
  - --trust_remote_code
  - --overwrite_output_dir
  - ${args}
method: grid
metric:
  # Accuracy is higher-is-better, so the sweep objective must be maximized;
  # this also matches metric_for_best_model below.
  goal: maximize
  name: eval/accuracy
parameters:
  # Swept: base checkpoint to fine-tune.
  model_name_or_path:
    values:
      - facebook/mms-lid-126
      - openai/whisper-large-v3
      - facebook/w2v-bert-2.0
  # Training data: "+"-separated lists; entries line up positionally across
  # the dataset name, config name, split name and label column name.
  train_dataset_name:
    value: sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc
  train_dataset_config_name:
    value: default+en_accented+default
  train_split_name:
    value: train+test+validation
  train_label_column_name:
    value: accent+accent+accent
  # Evaluation data.
  eval_dataset_name:
    value: sanchit-gandhi/edacc
  eval_dataset_config_name:
    value: default
  eval_split_name:
    value: test
  eval_label_column_name:
    value: accent
  output_dir:
    value: ./
  remove_unused_columns:
    value: false
  learning_rate:
    value: 1e-4
  max_length_seconds:
    value: 20
  attention_mask:
    value: false
  warmup_ratio:
    value: 0.1
  num_train_epochs:
    value: 5
  per_device_train_batch_size:
    value: 32
  per_device_eval_batch_size:
    value: 32
  preprocessing_num_workers:
    value: 16
  dataloader_num_workers:
    value: 4
  logging_strategy:
    value: steps
  logging_steps:
    value: 10
  # Evaluate and checkpoint once per epoch so --load_best_model_at_end can
  # select the best checkpoint by accuracy.
  evaluation_strategy:
    value: epoch
  save_strategy:
    value: epoch
  metric_for_best_model:
    value: accuracy
  save_total_limit:
    value: 3
  # Swept: train with the base encoder frozen vs. fully fine-tuned.
  freeze_base_model:
    values:
      - true
      - false
  push_to_hub:
    value: false
program: run_audio_classification.py
project: mms-lid-accent-classification
\ No newline at end of file
...@@ -197,7 +197,7 @@ class ModelArguments: ...@@ -197,7 +197,7 @@ class ModelArguments:
default=None, metadata={"help": "Name or path of preprocessor config."} default=None, metadata={"help": "Name or path of preprocessor config."}
) )
freeze_feature_encoder: bool = field( freeze_feature_encoder: bool = field(
default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."} default=False, metadata={"help": "Whether to freeze the feature encoder layers of the model. Only relevant for Wav2Vec2-style models."}
) )
freeze_base_model: bool = field( freeze_base_model: bool = field(
default=True, metadata={"help": "Whether to freeze the base encoder of the model."} default=True, metadata={"help": "Whether to freeze the base encoder of the model."}
...@@ -297,6 +297,7 @@ def load_multiple_datasets( ...@@ -297,6 +297,7 @@ def load_multiple_datasets(
dataset_config_names: Union[List, str], dataset_config_names: Union[List, str],
splits: Optional[Union[List, str]] = None, splits: Optional[Union[List, str]] = None,
label_column_names: Optional[List] = None, label_column_names: Optional[List] = None,
sampling_rate: Optional[int] = 16000,
stopping_strategy: Optional[str] = "first_exhausted", stopping_strategy: Optional[str] = "first_exhausted",
dataset_samples: Optional[Union[List, np.array]] = None, dataset_samples: Optional[Union[List, np.array]] = None,
streaming: Optional[bool] = False, streaming: Optional[bool] = False,
...@@ -332,6 +333,8 @@ def load_multiple_datasets( ...@@ -332,6 +333,8 @@ def load_multiple_datasets(
f" '{dataset_dict['name']}'. Make sure to set `--audio_column_name` to" f" '{dataset_dict['name']}'. Make sure to set `--audio_column_name` to"
f" the correct audio column - one of {', '.join(dataset_features)}." f" the correct audio column - one of {', '.join(dataset_features)}."
) )
# resample to specified sampling rate
dataset = dataset.cast_column("audio", datasets.features.Audio(sampling_rate))
if dataset_dict["label_column_name"] not in dataset_features: if dataset_dict["label_column_name"] not in dataset_features:
raise ValueError( raise ValueError(
...@@ -617,16 +620,19 @@ def main(): ...@@ -617,16 +620,19 @@ def main():
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes, ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
) )
# freeze the convolutional waveform encoder # freeze the convolutional waveform encoder for wav2vec2-style models
if model_args.freeze_feature_encoder: if model_args.freeze_feature_encoder:
model.freeze_feature_encoder() if hasattr(model, "freeze_feature_encoder"):
model.freeze_feature_encoder()
else:
raise ValueError("Method for freezing the feature encoder is not defined for Whisper-style models.")
if model_args.freeze_base_model: if model_args.freeze_base_model:
if model.hasattr("freeze_base_model"): if hasattr(model, "freeze_base_model"):
# wav2vec2-style models # wav2vec2-style models
model.freeze_base_model() model.freeze_base_model()
model.freeze_feature_encoder() model.freeze_feature_encoder()
elif model.hasattr("freeze_encoder"): elif hasattr(model, "freeze_encoder"):
# whisper-style models # whisper-style models
model.freeze_encoder() model.freeze_encoder()
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment