"vscode:/vscode.git/clone" did not exist on "0a0c952759968e95aaa98c5506ef481bcce695d2"
Commit 64cfa64e authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

update audio classification script

parent 10ef6f6c
#!/usr/bin/env bash
- python run_audio_classification.py \
- --model_name_or_path "facebook/mms-lid-126" \
+ CUDA_VISIBLE_DEVICES=2 python run_audio_classification_one_layer.py \
+ --model_name_or_path "facebook/mms-lid-4017" \
--train_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
--train_dataset_config_name "default" \
--train_split_name "train" \
@@ -10,11 +10,11 @@ python run_audio_classification.py \
--eval_dataset_config_name "default" \
--eval_split_name "test" \
--eval_label_column_name "labels" \
- --output_dir "./" \
+ --output_dir "./tmp/" \
--do_train \
--do_eval \
--overwrite_output_dir \
- --remove_unused_columns False \
+ --remove_unused_columns false \
--fp16 \
--fp16_full_eval \
--learning_rate 1e-4 \
@@ -30,9 +30,11 @@ python run_audio_classification.py \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "steps" \
- --eval_steps 500 \
+ --eval_steps 300 \
--save_strategy "no" \
--save_steps 2000 \
- --freeze_base_model True \
- --push_to_hub False \
- --trust_remote_code
+ --freeze_base_model true \
+ --freeze_feature_encoder true \
+ --push_to_hub false \
+ --trust_remote_code \
+ --use_weighted_layer_sum true \
command:
  - python3
  - ${program}
-   - --load_best_model_at_end
  - --fp16
  - --fp16_full_eval
  - --do_train
  - --do_eval
  - --trust_remote_code
  - --overwrite_output_dir
-   - --ignore_mismatched_sizes
-   - --gradient_checkpointing
  - ${args}
- method: grid
+ method: random
metric:
  goal: maximize
  name: eval/accuracy
parameters:
  model_name_or_path:
-     value: facebook/mms-lid-126
+     value: facebook/mms-lid-4017
  train_dataset_name:
-     value: stable-speech/concatenated-accent-dataset
+     value: "stable-speech/concatenated-normalized-accent-dataset+stable-speech/concatenated-common-voice-15-accented"
  train_dataset_config_name:
-     value: default
+     value: "default+default"
  train_split_name:
-     value: train
+     value: "train+train"
  train_label_column_name:
-     value: labels
+     value: "labels+labels"
  eval_dataset_name:
-     value: stable-speech/concatenated-accent-dataset
+     value: stable-speech/concatenated-normalized-accent-dataset
  eval_dataset_config_name:
    value: default
  eval_split_name:
@@ -35,31 +32,38 @@ parameters:
  eval_label_column_name:
    value: labels
  output_dir:
-     value: ./
+     value: "/raid/yoach/tmp/"
  remove_unused_columns:
    value: false
  learning_rate:
-     value: 1e-4
+     distribution: log_uniform_values
+     min: 3e-6
+     max: 0.01
  lr_scheduler_type:
-     value: constant_with_warmup
+     value: constant
  max_length_seconds:
    value: 20  # give some data diversity for longer audio samples
  min_length_seconds:
-     value: 7
+     value: 5
  attention_mask:
-     value: true
-   warmup_steps:
-     value: 100
-   max_steps:
-     value: 2000
+     values:
+       - true
+   num_train_epochs:
+     values:
+       - 2
+       - 5
+       - 10
+       - 20
+       - 40
+       - 60
  per_device_train_batch_size:
    value: 32
  per_device_eval_batch_size:
-     value: 16
+     value: 32
  preprocessing_num_workers:
-     value: 4
+     value: 8
  dataloader_num_workers:
-     value: 4
+     value: 8
  logging_strategy:
    value: steps
  logging_steps:
@@ -67,20 +71,28 @@ parameters:
  evaluation_strategy:
    value: steps
  eval_steps:
-     value: 1000
+     value: 2000
  save_strategy:
-     value: steps
+     value: "no"
  save_steps:
    value: 2000
  metric_for_best_model:
    value: accuracy
-   freeze_base_model:
-     values:
-       - false
-       - true
-   group_by_length:
-     value: false  # TODO(SG): batch by length
  push_to_hub:
    value: false
+   use_weighted_layer_sum:
+     value: false
+   freeze_base_model:
+     value: true
+   max_samples_per_label:
+     value: 10000
+   save_to_disk:
+     value: "/raid/yoach/tmp_dataset_accents/"
+   temporary_save_to_disk:
+     value: "/raid/yoach/tmp_hidden_states/"
+   use_last_embedding_layer:
+     value: true
+   filter_threshold:
+     value: "0.001"
program: run_audio_classification.py
- project: mms-lid-accent-classification
+ project: mms-lid-accent-classification-v2
\ No newline at end of file
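A sweep defined like this is normally registered with Weights & Biases and then executed by one or more agents. A minimal sketch, assuming the config above is saved as sweep.yaml next to run_audio_classification.py (file name and entity are placeholders):

import yaml
import wandb

# Register the sweep from the YAML above (assumed file name "sweep.yaml").
with open("sweep.yaml") as f:
    sweep_config = yaml.safe_load(f)
sweep_id = wandb.sweep(sweep=sweep_config, project="mms-lid-accent-classification-v2")

# Each GPU/worker then runs an agent, e.g. from the shell:
#   wandb agent <entity>/mms-lid-accent-classification-v2/<sweep_id>
print(sweep_id)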
@@ -21,7 +21,9 @@ import sys
from collections import Counter
from dataclasses import dataclass, field
from random import randint
from typing import List, Optional, Union, Dict

import torch
import datasets
import evaluate
@@ -41,7 +43,14 @@ from transformers import (
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Model
from transformers import Wav2Vec2BertForSequenceClassification, Wav2Vec2BertModel
from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
from torch.nn import CrossEntropyLoss

logger = logging.getLogger(__name__)

@@ -65,6 +74,81 @@ def deterministic_subsample(wav: np.ndarray, max_length: float, sample_rate: int
        return wav
    return wav[0:sample_length]

class SequenceClassificationModel(Wav2Vec2ForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
            )
        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # To bypass w2v2
        self.compute_w2v2 = True

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_features=None,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        hidden_states=None,  # added
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True

        if self.compute_w2v2:
            outputs = self.wav2vec2(
                input_features,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            hidden_states = outputs[_HIDDEN_STATES_START_POSITION][-1]  # take last embedding layer

            if attention_mask is None:
                pooled_output = hidden_states.mean(dim=1)
            else:
                padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
                hidden_states[~padding_mask] = 0.0
                pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
        else:
            pooled_output = hidden_states

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
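
# Usage sketch (hypothetical, for illustration only; checkpoint name and label count are placeholders):
#
#   model = SequenceClassificationModel.from_pretrained(
#       "facebook/mms-lid-4017", num_labels=10, ignore_mismatched_sizes=True
#   )
#   # 1) Full forward: raw 16 kHz waveforms -> Wav2Vec2 encoder -> pooled last layer -> classifier.
#   out = model(input_features=torch.randn(2, 16000), labels=torch.tensor([0, 1]))
#   # 2) Encoder bypass: classify precomputed, mean-pooled hidden states directly.
#   model.compute_w2v2 = False
#   out = model(hidden_states=torch.randn(2, model.config.hidden_size), labels=torch.tensor([0, 1]))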

# This list first defines the accent prefixes, which we use to strip the accent from CV
# e.g. England, southern accent, slight west-country expression -> England
@@ -84,7 +168,8 @@ STARTS_WITH = [
    "German",
    "Filipino",
    "India",
    "Irish",
    "Israeli",
    "Italian",
    "Japanese",
    "Kenyan",
@@ -252,6 +337,41 @@ def preprocess_labels(label: str) -> str:
        label = ACCENT_MAPPING[label]
    return label

@dataclass
class DataCollatorFeatureExtractorWithPadding:
    """
    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
    """

    feature_extractor: AutoFeatureExtractor
    max_length_seconds: int
    feature_extractor_input_name: Optional[str] = "input_values"

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        audios = [
            torch.tensor(
                deterministic_subsample(
                    feature["audio"]["array"],
                    max_length=self.max_length_seconds,
                    sample_rate=self.feature_extractor.sampling_rate,
                )
            )
            .squeeze()
            .numpy()
            for feature in features
        ]
        batch = self.feature_extractor(audios, return_tensors="pt", padding="longest", return_attention_mask=True)
        batch["labels"] = torch.tensor([feature["labels_id"] for feature in features])

        return batch


@dataclass
class DataCollatorHiddenStatesPadding:
    """
    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
    """

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        hidden_states = torch.stack([torch.tensor(feature["hidden_states"]) for feature in features])
        batch = {"hidden_states": hidden_states}
        batch["labels"] = torch.tensor([feature["labels_id"] for feature in features])

        return batch
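
# Usage sketch (hypothetical; `feature_extractor` and the example rows are placeholders):
#
#   collate_audio = DataCollatorFeatureExtractorWithPadding(feature_extractor, max_length_seconds=20)
#   batch = collate_audio(rows_with_audio)             # -> {"input_values", "attention_mask", "labels"}
#
#   collate_hidden = DataCollatorHiddenStatesPadding()
#   batch = collate_hidden(rows_with_hidden_states)    # -> {"hidden_states", "labels"}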

@dataclass
class DataTrainingArguments:
@@ -359,6 +479,26 @@ class DataTrainingArguments:
        default=1.0,
        metadata={"help": "Filter labels that occur less than `filter_threshold` percent in the training/eval data."},
    )
    max_samples_per_label: Optional[int] = field(
        default=None,
        metadata={"help": "If set, randomly limits the number of samples per label."},
    )
    save_to_disk: str = field(
        default=None,
        metadata={
            "help": "If set, will save the dataset to this path if this is an empty folder. If not empty, will load the datasets from it."
        },
    )
    temporary_save_to_disk: str = field(
        default=None,
        metadata={"help": "Temporarily save audio labels here."},
    )


@dataclass
@@ -369,7 +509,7 @@ class ModelArguments:
    model_name_or_path: str = field(
        default="facebook/wav2vec2-base",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models. Only works with Wav2Vec2 models"},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
@@ -465,6 +605,8 @@ class ModelArguments:
        metadata={"help": "Length of vector span to mask along the feature axis."},
    )
    layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
    use_weighted_layer_sum: bool = field(
        default=False, metadata={"help": "Whether to use a weighted average of layer outputs with learned weights."}
    )
    use_last_embedding_layer: bool = field(
        default=False, metadata={"help": "Whether to use the last layer hidden state. Only works with Wav2Vec2 models."}
    )

def convert_dataset_str_to_list(
@@ -661,75 +803,81 @@ def main():
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    if data_args.save_to_disk is not None:
        os.makedirs(data_args.save_to_disk, exist_ok=True)

    # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
    dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
    if dataset_was_precomputed:
        raw_datasets = datasets.load_from_disk(data_args.save_to_disk)
    else:
        # Initialize our dataset and prepare it for the audio classification task.
        raw_datasets = DatasetDict()

        if training_args.do_train:
            raw_datasets["train"] = load_multiple_datasets(
                data_args.train_dataset_name,
                data_args.train_dataset_config_name,
                splits=data_args.train_split_name,
                label_column_names=data_args.train_label_column_name,
                dataset_samples=data_args.train_dataset_samples,
                seed=training_args.seed,
                cache_dir=model_args.cache_dir,
                token=True if model_args.token else None,
                trust_remote_code=model_args.trust_remote_code,
                num_proc=data_args.preprocessing_num_workers,
                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
            )

        if training_args.do_eval:
            dataset_names_dict = convert_dataset_str_to_list(
                data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
                (
                    data_args.eval_dataset_config_name
                    if data_args.eval_dataset_config_name
                    else data_args.train_dataset_config_name
                ),
                splits=data_args.eval_split_name,
                label_column_names=data_args.eval_label_column_name,
            )
            all_eval_splits = []
            # load multiple eval sets
            for dataset_dict in dataset_names_dict:
                pretty_name = (
                    f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
                    if len(dataset_names_dict) > 1
                    else "eval"
                )
                all_eval_splits.append(pretty_name)
                raw_datasets[pretty_name] = load_dataset(
                    dataset_dict["name"],
                    dataset_dict["config"],
                    split=dataset_dict["split"],
                    cache_dir=model_args.cache_dir,
                    token=True if model_args.token else None,
                    trust_remote_code=model_args.trust_remote_code,
                    num_proc=data_args.preprocessing_num_workers,
                    # streaming=data_args.streaming,
                )
                features = raw_datasets[pretty_name].features.keys()
                if dataset_dict["label_column_name"] not in features:
                    raise ValueError(
                        f"--label_column_name {data_args.eval_label_column_name} not found in dataset '{data_args.dataset_name}'. "
                        "Make sure to set `--label_column_name` to the correct text column - one of "
                        f"{', '.join(raw_datasets['train'].column_names)}."
                    )
                elif dataset_dict["label_column_name"] != "labels":
                    raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
                        dataset_dict["label_column_name"], "labels"
                    )
                raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
                    set(raw_datasets[pretty_name].features.keys()) - {"audio", "labels"}
                )

    if not training_args.do_train and not training_args.do_eval:
        raise ValueError(
            "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
        )

    # Setting `return_attention_mask=True` is the way to get a correctly masked mean-pooling over
    # transformer outputs in the classifier, but it doesn't always lead to better accuracy
@@ -741,126 +889,163 @@ def main():
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    feature_extractor_input_name = feature_extractor.model_input_names[0]

    if not dataset_was_precomputed:
        # `datasets` takes care of automatically loading and resampling the audio,
        # so we just need to set the correct target sampling rate.
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
        )

    with training_args.main_process_first():
        if training_args.do_train:
            if data_args.max_train_samples is not None:
                raw_datasets["train"] = (
                    raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
                )

        if training_args.do_eval:
            if data_args.max_eval_samples is not None:
                raw_datasets["eval"] = (
                    raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
                )

    sampling_rate = feature_extractor.sampling_rate
    model_input_name = feature_extractor.model_input_names[0]

    if not dataset_was_precomputed:
        def prepare_dataset(audio, labels):
            batch = {}
            batch["length"] = len(audio["array"])
            batch["labels"] = preprocess_labels(labels)
            return batch

        with training_args.main_process_first():
            tmp_datasets = raw_datasets.map(
                prepare_dataset,
                num_proc=data_args.preprocessing_num_workers,
                input_columns=["audio", "labels"],
                # this is a trick to avoid rewriting the entire audio column, which takes ages
                remove_columns=[col for col in next(iter(raw_datasets.values())).column_names if col != "labels"],
                desc="Computing audio length",
            )
            for split in raw_datasets:
                raw_datasets[split] = concatenate_datasets(
                    [raw_datasets[split].remove_columns(["labels"]), tmp_datasets[split]], axis=1
                )

    # filter training data with inputs < min_input_length
    min_input_length = data_args.min_length_seconds * sampling_rate

    if not dataset_was_precomputed:
        def is_audio_valid(input_length):
            return input_length > min_input_length

        with training_args.main_process_first():
            raw_datasets = raw_datasets.filter(
                is_audio_valid,
                input_columns=["length"],
                num_proc=data_args.preprocessing_num_workers,
                desc="Filtering by audio length",
            )

        # filter training data with non-valid labels
        def is_label_valid(label):
            return label != "Unknown"

        with training_args.main_process_first():
            raw_datasets = raw_datasets.filter(
                is_label_valid,
                input_columns=["labels"],
                num_proc=data_args.preprocessing_num_workers,
                desc="Filtering by labels",
            )

        if training_args.do_train and data_args.max_samples_per_label:
            label_names = set(raw_datasets["train"]["labels"])
            labels = np.array(raw_datasets["train"]["labels"])
            indices = np.arange(len(labels))
            indices_to_keep = []
            set_seed(training_args.seed)

            for label in label_names:
                label_indices = indices[labels == label]
                label_indices = np.random.choice(label_indices, size=min(data_args.max_samples_per_label, len(label_indices)), replace=False)
                indices_to_keep.extend(np.random.choice(label_indices, size=min(data_args.max_samples_per_label, len(label_indices)), replace=False))

            with training_args.main_process_first():
                raw_datasets["train"] = raw_datasets["train"].select(indices_to_keep).shuffle(seed=training_args.seed)

        # Print a summary of the labels to the stdout (helps identify low-label classes that could be filtered)
        # sort by freq
        count_labels_dict = Counter(raw_datasets["train"]["labels"])
        count_labels_dict = sorted(count_labels_dict.items(), key=lambda item: (-item[1], item[0]))
        labels, frequencies = zip(*count_labels_dict)
        total_labels = sum(frequencies)
        labels_to_remove = []

        logger.info(f"{'Accent':<15} {'Perc.':<5}")
        logger.info("-" * 20)
        for lab, freq in zip(labels, frequencies):
            freq = 100 * freq / total_labels
            logger.info(f"{lab:<15} {freq:<5}")
            if freq < data_args.filter_threshold:
                labels_to_remove.append(lab)

        if len(labels_to_remove):
            # filter training data with label freq below threshold
            def is_label_valid(label):
                return label not in labels_to_remove

            logger.info(f"Labels removed: {labels_to_remove}")
            with training_args.main_process_first():
                raw_datasets = raw_datasets.filter(
                    is_label_valid,
                    input_columns=["labels"],
                    num_proc=data_args.preprocessing_num_workers,
                    desc="Filtering low freq labels",
                )

        # We'll include these in the model's config to get human readable labels in the Inference API.
        set_labels = set(raw_datasets["train"]["labels"])
        if training_args.do_eval:
            logger.info(f"The following accents are present in the eval set but not the test set: {set(raw_datasets['eval']['labels']) - set_labels}")
            set_labels = set_labels.union(set(raw_datasets["eval"]["labels"]))

        label2id, id2label = {}, {}
        for i, label in enumerate(sorted(list(set_labels))):
            label2id[label] = str(i)
            id2label[str(i)] = label

        with training_args.main_process_first():
            raw_datasets = raw_datasets.map(
                lambda label: {"labels_id": int(label2id[label])},
                input_columns=["labels"],
                num_proc=data_args.preprocessing_num_workers,
                desc="Apply labels id",
            )
    else:
        label2id, id2label = {}, {}
        # TODO: slow - probably not best
        if training_args.do_train:
            for sample in raw_datasets["train"]:
                candidate = label2id.get(sample["labels"])
                if candidate is not None and candidate != sample["labels_id"]:
                    print(f"issue {candidate} should be {sample['labels_id']}")
                label2id[sample["labels"]] = sample["labels_id"]
        if training_args.do_eval:
            for sample in raw_datasets["eval"]:
                candidate = label2id.get(sample["labels"])
                if candidate is not None and candidate != sample["labels_id"]:
                    print(f"issue {candidate} should be {sample['labels_id']}")
                label2id[sample["labels"]] = sample["labels_id"]

        # if training_args.do_train:
        #     label2id = dict(zip(raw_datasets["train"]["labels"], raw_datasets["train"]["labels_id"]))
        # if training_args.do_eval:
        #     label2id.update(dict(zip(raw_datasets["eval"]["labels"], raw_datasets["eval"]["labels_id"])))
        id2label = {str(val): key for (key, val) in label2id.items()}

    # Load the accuracy metric from the datasets package
    metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
@@ -895,20 +1080,35 @@ def main():
            "mask_feature_length": model_args.mask_feature_length,
            "layerdrop": model_args.layerdrop,
            "activation_dropout": model_args.activation_dropout,
            "use_weighted_layer_sum": model_args.use_weighted_layer_sum,
        }
    )

    if model_args.use_last_embedding_layer:
        model = SequenceClassificationModel.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
        )
        if model_args.freeze_base_model:
            model.compute_w2v2 = False
    else:
        model = AutoModelForAudioClassification.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
        )

    # freeze the convolutional waveform encoder for wav2vec2-style models
    if model_args.freeze_feature_encoder:
        if hasattr(model, "freeze_feature_encoder"):
@@ -928,15 +1128,91 @@ def main():
        else:
            raise ValueError("Method for freezing the base module of the audio encoder is not defined")

    if model_args.freeze_base_model and model_args.use_last_embedding_layer:
        if not dataset_was_precomputed:
            # precomputing hidden states
            from torch.utils.data import DataLoader
            from accelerate import Accelerator

            _HIDDEN_STATES_START_POSITION = 2

            if training_args.fp16:
                mixed_precision = "fp16"
            elif training_args.bf16:
                mixed_precision = "bf16"
            else:
                mixed_precision = "no"

            accelerator = Accelerator(
                gradient_accumulation_steps=training_args.gradient_accumulation_steps,
                mixed_precision=mixed_precision,
                project_dir=training_args.output_dir,
            )

            audio_data_collator = DataCollatorFeatureExtractorWithPadding(
                feature_extractor,
                max_length_seconds=data_args.max_length_seconds,
                feature_extractor_input_name=feature_extractor_input_name,
            )

            for split in raw_datasets:
                data_loader = DataLoader(
                    raw_datasets[split],
                    batch_size=training_args.per_device_train_batch_size,  # TODO: chose another one
                    collate_fn=audio_data_collator,
                    num_workers=training_args.dataloader_num_workers,
                    pin_memory=True,
                )
                data_loader = accelerator.prepare(data_loader)

                all_encoder_outputs = []
                for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
                    model.wav2vec2.to(batch[feature_extractor_input_name].device)
                    with torch.no_grad():
                        encoder_outputs = model.wav2vec2(
                            batch[feature_extractor_input_name],
                            attention_mask=batch.get("attention_mask", None),
                            output_hidden_states=True,
                        )
                        hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION][-1]

                        if batch.get("attention_mask", None) is None:
                            hidden_states = hidden_states.mean(dim=1)
                        else:
                            padding_mask = model._get_feature_vector_attention_mask(hidden_states.shape[1], batch.get("attention_mask", None))
                            hidden_states[~padding_mask] = 0.0
                            hidden_states = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

                    encoder_outputs = accelerator.gather_for_metrics(hidden_states)
                    # TODO: check it works multi device
                    if accelerator.is_main_process:
                        all_encoder_outputs.extend(encoder_outputs.to("cpu").numpy())

                if accelerator.is_main_process:
                    tmp_hidden_states = Dataset.from_dict({"hidden_states": all_encoder_outputs})
                    tmp_hidden_states.save_to_disk(os.path.join(data_args.temporary_save_to_disk, split))
                accelerator.wait_for_everyone()
                del all_encoder_outputs

                tmp_hidden_states = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
                with accelerator.main_process_first():
                    raw_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns("audio"), tmp_hidden_states], axis=1)

            accelerator.free_memory()
            del data_loader, batch, accelerator

        data_collator = DataCollatorHiddenStatesPadding()
    else:
        data_collator = DataCollatorFeatureExtractorWithPadding(feature_extractor, max_length_seconds=data_args.max_length_seconds)

    if data_args.save_to_disk is not None and not dataset_was_precomputed:
        raw_datasets.save_to_disk(data_args.save_to_disk)
        logger.info(f"Dataset saved at {data_args.save_to_disk}. Be careful of changing data parameters, which won't change the current saved dataset if reloaded.")

    # Initialize our trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=raw_datasets["train"] if training_args.do_train else None,
        eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
    )

    ignore_keys = ["hidden_states", "attentions"] if model_args.use_weighted_layer_sum else None

    # Training
    if training_args.do_train:
@@ -945,7 +1221,7 @@ def main():
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint, ignore_keys_for_eval=ignore_keys)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
@@ -953,7 +1229,7 @@ def main():
    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(ignore_keys=ignore_keys)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)