Commit 64cfa64e authored by Yoach Lacombe

update audio classification script

parent 10ef6f6c
#!/usr/bin/env bash
-python run_audio_classification.py \
+CUDA_VISIBLE_DEVICES=2 python run_audio_classification_one_layer.py \
-    --model_name_or_path "facebook/mms-lid-126" \
+    --model_name_or_path "facebook/mms-lid-4017" \
    --train_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
    --train_dataset_config_name "default" \
    --train_split_name "train" \
@@ -10,11 +10,11 @@ python run_audio_classification.py \
    --eval_dataset_config_name "default" \
    --eval_split_name "test" \
    --eval_label_column_name "labels" \
-    --output_dir "./" \
+    --output_dir "./tmp/" \
    --do_train \
    --do_eval \
    --overwrite_output_dir \
-    --remove_unused_columns False \
+    --remove_unused_columns false \
    --fp16 \
    --fp16_full_eval \
    --learning_rate 1e-4 \
@@ -30,9 +30,11 @@ python run_audio_classification.py \
    --logging_strategy "steps" \
    --logging_steps 10 \
    --evaluation_strategy "steps" \
-    --eval_steps 500 \
+    --eval_steps 300 \
    --save_strategy "no" \
    --save_steps 2000 \
-    --freeze_base_model True \
+    --freeze_base_model true \
+    --freeze_feature_encoder true \
-    --push_to_hub False \
+    --push_to_hub false \
-    --trust_remote_code
+    --trust_remote_code \
+    --use_weighted_layer_sum true \
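For reference (not part of this diff), the dataset-caching and layer-selection options introduced by this commit can be combined with the launcher above. This is only an illustrative sketch; the cache paths are placeholders taken from the sweep configuration below:

# Illustrative only: the new flags added in this commit, with placeholder cache paths.
# Combine with the dataset/training flags from the script above for a real run.
python run_audio_classification.py \
    --model_name_or_path "facebook/mms-lid-4017" \
    --freeze_base_model true \
    --use_last_embedding_layer true \
    --max_samples_per_label 10000 \
    --save_to_disk "/raid/yoach/tmp_dataset_accents/" \
    --temporary_save_to_disk "/raid/yoach/tmp_hidden_states/"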
command:
- python3
- ${program}
-- --load_best_model_at_end
- --fp16
- --fp16_full_eval
- --do_train
- --do_eval
- --trust_remote_code
- --overwrite_output_dir
-- --ignore_mismatched_sizes
-- --gradient_checkpointing
- ${args}
-method: grid
+method: random
metric:
  goal: maximize
  name: eval/accuracy
parameters:
  model_name_or_path:
-    value: facebook/mms-lid-126
+    value: facebook/mms-lid-4017
  train_dataset_name:
-    value: stable-speech/concatenated-accent-dataset
+    value: "stable-speech/concatenated-normalized-accent-dataset+stable-speech/concatenated-common-voice-15-accented"
  train_dataset_config_name:
-    value: default
+    value: "default+default"
  train_split_name:
-    value: train
+    value: "train+train"
  train_label_column_name:
-    value: labels
+    value: "labels+labels"
  eval_dataset_name:
-    value: stable-speech/concatenated-accent-dataset
+    value: stable-speech/concatenated-normalized-accent-dataset
  eval_dataset_config_name:
    value: default
  eval_split_name:
@@ -35,31 +32,38 @@ parameters:
  eval_label_column_name:
    value: labels
  output_dir:
-    value: ./
+    value: "/raid/yoach/tmp/"
  remove_unused_columns:
    value: false
  learning_rate:
-    value: 1e-4
+    distribution: log_uniform_values
+    min: 3e-6
+    max: 0.01
  lr_scheduler_type:
-    value: constant_with_warmup
+    value: constant
  max_length_seconds:
    value: 20 # give some data diversity for longer audio samples
  min_length_seconds:
-    value: 7
+    value: 5
  attention_mask:
-    value: true
+    values:
+    - true
-  warmup_steps:
-    value: 100
-  max_steps:
-    value: 2000
+  num_train_epochs:
+    values:
+    - 2
+    - 5
+    - 10
+    - 20
+    - 40
+    - 60
  per_device_train_batch_size:
    value: 32
  per_device_eval_batch_size:
-    value: 16
+    value: 32
  preprocessing_num_workers:
-    value: 4
+    value: 8
  dataloader_num_workers:
-    value: 4
+    value: 8
  logging_strategy:
    value: steps
  logging_steps:
@@ -67,20 +71,28 @@ parameters:
  evaluation_strategy:
    value: steps
  eval_steps:
-    value: 1000
+    value: 2000
  save_strategy:
-    value: steps
+    value: "no"
  save_steps:
    value: 2000
  metric_for_best_model:
    value: accuracy
-  freeze_base_model:
-    values:
-    - false
-    - true
-  group_by_length:
-    value: false # TODO(SG): batch by length
  push_to_hub:
    value: false
+  use_weighted_layer_sum:
+    value: false
+  freeze_base_model:
+    value: true
+  max_samples_per_label:
+    value: 10000
+  save_to_disk:
+    value: "/raid/yoach/tmp_dataset_accents/"
+  temporary_save_to_disk:
+    value: "/raid/yoach/tmp_hidden_states/"
+  use_last_embedding_layer:
+    value: true
+  filter_threshold:
+    value: "0.001"
program: run_audio_classification.py
-project: mms-lid-accent-classification
+project: mms-lid-accent-classification-v2
\ No newline at end of file
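As a usage note (not part of the diff), a program-based sweep config like the one above is typically registered and launched with the wandb CLI; the file name and entity below are placeholders:

# Hypothetical file name for the config above.
wandb sweep --project mms-lid-accent-classification-v2 sweep.yaml   # prints a sweep ID
wandb agent <entity>/mms-lid-accent-classification-v2/<sweep_id>    # runs run_audio_classification.py with sampled parameters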
@@ -21,7 +21,9 @@ import sys
from collections import Counter
from dataclasses import dataclass, field
from random import randint
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Dict
+
+import torch

import datasets
import evaluate
...@@ -41,7 +43,14 @@ from transformers import ( ...@@ -41,7 +43,14 @@ from transformers import (
from transformers.models.whisper.tokenization_whisper import LANGUAGES from transformers.models.whisper.tokenization_whisper import LANGUAGES
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version from transformers.utils import check_min_version
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Model
from transformers import Wav2Vec2BertForSequenceClassification, Wav2Vec2BertModel
from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
from torch.nn import CrossEntropyLoss
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -65,6 +74,81 @@ def deterministic_subsample(wav: np.ndarray, max_length: float, sample_rate: int
        return wav
    return wav[0:sample_length]

+class SequenceClassificationModel(Wav2Vec2ForSequenceClassification):
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Sequence classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
+            )
+        self.wav2vec2 = Wav2Vec2Model(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Flag that allows bypassing the wav2vec2 encoder when precomputed hidden states are passed
+        self.compute_w2v2 = True
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_features=None,
+        attention_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+        hidden_states=None,  # added
+    ):
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True
+
+        if self.compute_w2v2:
+            outputs = self.wav2vec2(
+                input_features,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION][-1]  # take last embedding layer
+
+            if attention_mask is None:
+                pooled_output = hidden_states.mean(dim=1)
+            else:
+                padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+                hidden_states[~padding_mask] = 0.0
+                pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+        else:
+            pooled_output = hidden_states
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+        )

# This list first defines the accent prefixes, which we use to strip the accent from CV
# e.g. England, southern accent, slight west-country expression -> England
@@ -84,7 +168,8 @@ STARTS_WITH = [
    "German",
    "Filipino",
    "India",
-    "Irish" "Israeli",
+    "Irish",
+    "Israeli",
    "Italian",
    "Japanese",
    "Kenyan",
@@ -253,6 +338,41 @@ def preprocess_labels(label: str) -> str:
    return label

+@dataclass
+class DataCollatorFeatureExtractorWithPadding:
+    """
+    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
+    """
+
+    feature_extractor: AutoFeatureExtractor
+    max_length_seconds: int
+    feature_extractor_input_name: Optional[str] = "input_values"
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        # split inputs and labels since they have to be of different lengths and need
+        # different padding methods
+        audios = [torch.tensor(deterministic_subsample(feature["audio"]["array"], max_length=self.max_length_seconds, sample_rate=self.feature_extractor.sampling_rate)).squeeze().numpy() for feature in features]
+        batch = self.feature_extractor(audios, return_tensors="pt", padding="longest", return_attention_mask=True)
+        batch["labels"] = torch.tensor([feature["labels_id"] for feature in features])
+
+        return batch
+
+
+@dataclass
+class DataCollatorHiddenStatesPadding:
+    """
+    Data collator that batches precomputed (already pooled) hidden states together with their labels.
+    """
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        hidden_states = torch.stack([torch.tensor(feature["hidden_states"]) for feature in features])
+        batch = {"hidden_states": hidden_states}
+        batch["labels"] = torch.tensor([feature["labels_id"] for feature in features])
+
+        return batch

@dataclass
class DataTrainingArguments:
    """
@@ -359,6 +479,26 @@ class DataTrainingArguments:
        default=1.0,
        metadata={"help": "Filter labels that occur less than `filter_threshold` percent in the training/eval data."},
    )
+    max_samples_per_label: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "If set, randomly limits the number of samples per label."
+            )
+        },
+    )
+    save_to_disk: str = field(
+        default=None,
+        metadata={
+            "help": "If set, the processed dataset will be saved to this path when the folder is empty. If the folder is not empty, the datasets will be loaded from it."
+        },
+    )
+    temporary_save_to_disk: str = field(
+        default=None,
+        metadata={
+            "help": "Temporarily save the precomputed hidden states here."
+        },
+    )

@dataclass
@@ -369,7 +509,7 @@ class ModelArguments:
    model_name_or_path: str = field(
        default="facebook/wav2vec2-base",
-        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
+        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models. Only works with Wav2Vec2 models"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -465,6 +605,8 @@ class ModelArguments:
        metadata={"help": "Length of vector span to mask along the feature axis."},
    )
    layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
+    use_weighted_layer_sum: bool = field(default=False, metadata={"help": "Whether to use a weighted average of layer outputs with learned weights."})
+    use_last_embedding_layer: bool = field(default=False, metadata={"help": "Whether to use the last layer's hidden state as the utterance embedding. Only works with Wav2Vec2 models."})

def convert_dataset_str_to_list(
@@ -661,10 +803,16 @@ def main():
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

+    if data_args.save_to_disk is not None:
+        os.makedirs(data_args.save_to_disk, exist_ok=True)
+
+    # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
+    dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
+    if dataset_was_precomputed:
+        raw_datasets = datasets.load_from_disk(data_args.save_to_disk)
+    else:
        # Initialize our dataset and prepare it for the audio classification task.
        raw_datasets = DatasetDict()

        # set seed for determinism
        set_seed(training_args.seed)

        if training_args.do_train:
            raw_datasets["train"] = load_multiple_datasets(
@@ -742,12 +890,16 @@ def main():
        trust_remote_code=model_args.trust_remote_code,
    )
+    feature_extractor_input_name = feature_extractor.model_input_names[0]

+    if not dataset_was_precomputed:
        # `datasets` takes care of automatically loading and resampling the audio,
        # so we just need to set the correct target sampling rate.
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
        )

+    with training_args.main_process_first():
        if training_args.do_train:
            if data_args.max_train_samples is not None:
                raw_datasets["train"] = (
@@ -763,23 +915,33 @@ def main():
    sampling_rate = feature_extractor.sampling_rate
    model_input_name = feature_extractor.model_input_names[0]

-    def prepare_dataset(batch):
-        batch["length"] = len(batch["audio"]["array"])
-        batch["labels"] = preprocess_labels(batch["labels"])
+    if not dataset_was_precomputed:
+        def prepare_dataset(audio, labels):
+            batch = {}
+            batch["length"] = len(audio["array"])
+            batch["labels"] = preprocess_labels(labels)
            return batch

-    raw_datasets = raw_datasets.map(
+        with training_args.main_process_first():
+            tmp_datasets = raw_datasets.map(
                prepare_dataset,
                num_proc=data_args.preprocessing_num_workers,
+                input_columns=["audio", "labels"],
+                remove_columns=[col for col in next(iter(raw_datasets.values())).column_names if col != "labels"],  # this is a trick to avoid rewriting the entire audio column, which takes ages
                desc="Computing audio length",
            )
+        for split in raw_datasets:
+            raw_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(["labels"]), tmp_datasets[split]], axis=1)

    # filter training data with inputs < min_input_length
    min_input_length = data_args.min_length_seconds * sampling_rate
+    if not dataset_was_precomputed:
        def is_audio_valid(input_length):
            return input_length > min_input_length

+        with training_args.main_process_first():
            raw_datasets = raw_datasets.filter(
                is_audio_valid,
                input_columns=["length"],
@@ -791,6 +953,7 @@ def main():
    def is_label_valid(label):
        return label != "Unknown"

+    with training_args.main_process_first():
        raw_datasets = raw_datasets.filter(
            is_label_valid,
            input_columns=["labels"],
@@ -798,6 +961,20 @@ def main():
            desc="Filtering by labels",
        )

+    if training_args.do_train and data_args.max_samples_per_label:
+        label_names = set(raw_datasets["train"]["labels"])
+        labels = np.array(raw_datasets["train"]["labels"])
+        indices = np.arange(len(labels))
+        indices_to_keep = []
+
+        set_seed(training_args.seed)
+        for label in label_names:
+            label_indices = indices[labels == label]
+            label_indices = np.random.choice(label_indices, size=min(data_args.max_samples_per_label, len(label_indices)), replace=False)
+            indices_to_keep.extend(np.random.choice(label_indices, size=min(data_args.max_samples_per_label, len(label_indices)), replace=False))
+
+        with training_args.main_process_first():
+            raw_datasets["train"] = raw_datasets["train"].select(indices_to_keep).shuffle(seed=training_args.seed)

    # Print a summary of the labels to the stdout (helps identify low-label classes that could be filtered)
    # sort by freq
    count_labels_dict = Counter(raw_datasets["train"]["labels"])
@@ -819,6 +996,9 @@ def main():
    def is_label_valid(label):
        return label not in labels_to_remove

+    logger.info(f"Labels removed: {labels_to_remove}")
+
+    with training_args.main_process_first():
        raw_datasets = raw_datasets.filter(
            is_label_valid,
            input_columns=["labels"],
@@ -829,37 +1009,42 @@ def main():
    # We'll include these in the model's config to get human readable labels in the Inference API.
    set_labels = set(raw_datasets["train"]["labels"])
    if training_args.do_eval:
+        logger.info(f"The following accents are present in the eval set but not the train set: {set(raw_datasets['eval']['labels']) - set_labels}")
        set_labels = set_labels.union(set(raw_datasets["eval"]["labels"]))
    label2id, id2label = {}, {}
-    for i, label in enumerate(set(set_labels)):
+    for i, label in enumerate(sorted(list(set_labels))):
        label2id[label] = str(i)
        id2label[str(i)] = label

-    def train_transforms(batch):
-        """Apply train_transforms across a batch."""
-        subsampled_wavs = []
-        for audio in batch["audio"]:
-            wav = deterministic_subsample(
-                audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
-            )
-            subsampled_wavs.append(wav)
-        inputs = feature_extractor(
-            subsampled_wavs, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
-        )
-        output_batch = {
-            model_input_name: inputs.get(model_input_name),
-            "attention_mask": inputs.get("attention_mask"),
-            "labels": [int(label2id[label]) for label in batch["labels"]],
-        }
-        return output_batch
+    with training_args.main_process_first():
+        raw_datasets = raw_datasets.map(
+            lambda label: {"labels_id": int(label2id[label])},
+            input_columns=["labels"],
+            num_proc=data_args.preprocessing_num_workers,
+            desc="Apply labels id",
+        )
+    else:
+        label2id, id2label = {}, {}

+        # TODO: slow - probably not best
    if training_args.do_train:
-        # Set the training transforms
-        raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
+        for sample in raw_datasets["train"]:
+            candidate = label2id.get(sample["labels"])
+            if candidate is not None and candidate != sample["labels_id"]:
+                print(f"issue {candidate} should be {sample['labels_id']}")
+            label2id[sample["labels"]] = sample["labels_id"]
    if training_args.do_eval:
-        # Set the validation transforms
-        raw_datasets["eval"].set_transform(train_transforms, output_all_columns=False)
+        for sample in raw_datasets["eval"]:
+            candidate = label2id.get(sample["labels"])
+            if candidate is not None and candidate != sample["labels_id"]:
+                print(f"issue {candidate} should be {sample['labels_id']}")
+            label2id[sample["labels"]] = sample["labels_id"]

+    # if training_args.do_train:
+    #     label2id = dict(zip(raw_datasets["train"]["labels"], raw_datasets["train"]["labels_id"]))
+    # if training_args.do_eval:
+    #     label2id.update(dict(zip(raw_datasets["eval"]["labels"], raw_datasets["eval"]["labels_id"])))
+    id2label = {str(val): key for (key, val) in label2id.items()}

    # Load the accuracy metric from the datasets package
    metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
@@ -895,10 +1080,12 @@ def main():
            "mask_feature_length": model_args.mask_feature_length,
            "layerdrop": model_args.layerdrop,
            "activation_dropout": model_args.activation_dropout,
+            "use_weighted_layer_sum": model_args.use_weighted_layer_sum,
        }
    )

-    model = AutoModelForAudioClassification.from_pretrained(
+    if model_args.use_last_embedding_layer:
+        model = SequenceClassificationModel.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
@@ -909,6 +1096,19 @@ def main():
            ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
        )

+        if model_args.freeze_base_model:
+            model.compute_w2v2 = False
+    else:
+        model = AutoModelForAudioClassification.from_pretrained(
+            model_args.model_name_or_path,
+            from_tf=bool(".ckpt" in model_args.model_name_or_path),
+            config=config,
+            cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            token=model_args.token,
+            trust_remote_code=model_args.trust_remote_code,
+            ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
+        )

    # freeze the convolutional waveform encoder for wav2vec2-style models
    if model_args.freeze_feature_encoder:
        if hasattr(model, "freeze_feature_encoder"):
@@ -928,15 +1128,91 @@ def main():
        else:
            raise ValueError("Method for freezing the base module of the audio encoder is not defined")

+    if model_args.freeze_base_model and model_args.use_last_embedding_layer:
+        if not dataset_was_precomputed:
+            # precomputing hidden states
+            from torch.utils.data import DataLoader
+            from accelerate import Accelerator
+
+            _HIDDEN_STATES_START_POSITION = 2
+
+            if training_args.fp16:
+                mixed_precision = "fp16"
+            elif training_args.bf16:
+                mixed_precision = "bf16"
+            else:
+                mixed_precision = "no"
+
+            accelerator = Accelerator(
+                gradient_accumulation_steps=training_args.gradient_accumulation_steps,
+                mixed_precision=mixed_precision,
+                project_dir=training_args.output_dir,
+            )
+
+            audio_data_collator = DataCollatorFeatureExtractorWithPadding(feature_extractor, max_length_seconds=data_args.max_length_seconds, feature_extractor_input_name=feature_extractor_input_name)
+
+            for split in raw_datasets:
+                data_loader = DataLoader(
+                    raw_datasets[split],
+                    batch_size=training_args.per_device_train_batch_size,  # TODO: choose another one
+                    collate_fn=audio_data_collator,
+                    num_workers=training_args.dataloader_num_workers,
+                    pin_memory=True,
+                )
+                data_loader = accelerator.prepare(data_loader)
+
+                all_encoder_outputs = []
+                for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
+                    model.wav2vec2.to(batch[feature_extractor_input_name].device)
+                    with torch.no_grad():
+                        encoder_outputs = model.wav2vec2(batch[feature_extractor_input_name], attention_mask=batch.get("attention_mask", None), output_hidden_states=True)
+                        hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION][-1]
+
+                        if batch.get("attention_mask", None) is None:
+                            hidden_states = hidden_states.mean(dim=1)
+                        else:
+                            padding_mask = model._get_feature_vector_attention_mask(hidden_states.shape[1], batch.get("attention_mask", None))
+                            hidden_states[~padding_mask] = 0.0
+                            hidden_states = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+                    encoder_outputs = accelerator.gather_for_metrics(hidden_states)
+                    # TODO: check it works multi device
+                    if accelerator.is_main_process:
+                        all_encoder_outputs.extend(encoder_outputs.to("cpu").numpy())
+
+                if accelerator.is_main_process:
+                    tmp_hidden_states = Dataset.from_dict({"hidden_states": all_encoder_outputs})
+                    tmp_hidden_states.save_to_disk(os.path.join(data_args.temporary_save_to_disk, split))
+                accelerator.wait_for_everyone()
+                del all_encoder_outputs
+
+                tmp_hidden_states = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
+                with accelerator.main_process_first():
+                    raw_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns("audio"), tmp_hidden_states], axis=1)
+
+            accelerator.free_memory()
+            del data_loader, batch, accelerator
+
+        data_collator = DataCollatorHiddenStatesPadding()
+    else:
+        data_collator = DataCollatorFeatureExtractorWithPadding(feature_extractor, max_length_seconds=data_args.max_length_seconds)
+
+    if data_args.save_to_disk is not None and not dataset_was_precomputed:
+        raw_datasets.save_to_disk(data_args.save_to_disk)
+        logger.info(f"Dataset saved at {data_args.save_to_disk}. Be careful: changing the data parameters will not update an already saved dataset when it is reloaded.")

    # Initialize our trainer
    trainer = Trainer(
        model=model,
        args=training_args,
+        data_collator=data_collator,
        train_dataset=raw_datasets["train"] if training_args.do_train else None,
        eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
    )

+    ignore_keys = ["hidden_states", "attentions"] if model_args.use_weighted_layer_sum else None

    # Training
    if training_args.do_train:
@@ -945,7 +1221,7 @@ def main():
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
-        train_result = trainer.train(resume_from_checkpoint=checkpoint)
+        train_result = trainer.train(resume_from_checkpoint=checkpoint, ignore_keys_for_eval=ignore_keys)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
@@ -953,7 +1229,7 @@ def main():
    # Evaluation
    if training_args.do_eval:
-        metrics = trainer.evaluate()
+        metrics = trainer.evaluate(ignore_keys=ignore_keys)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)