Commit 4664d695 authored by sanchit-gandhi

Audio classification on multiple datasets

parent 6089d39b
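This commit switches the audio-classification example from a single `dataset_name`/`dataset_config_name` pair to '+'-separated train/eval dataset specifications that are loaded and interleaved together. A rough sketch of how the new flags might be combined (the flag names come from the new `DataTrainingArguments` fields below; the dataset ids, configs, and sample counts are purely illustrative, and other required arguments such as the model and output directory are omitted):

# Hypothetical invocation of the modified script; only the multi-dataset flags are shown.
import subprocess

subprocess.run(
    [
        "python", "run_audio_classification.py",
        "--train_dataset_name", "librispeech_asr+common_voice",  # two datasets, '+'-separated
        "--train_dataset_config_name", "clean+en",               # one config per dataset
        "--train_dataset_samples", "100+360",                    # relative sampling weights per dataset
        "--eval_dataset_name", "librispeech_asr",                # single eval set
        "--eval_dataset_config_name", "clean",
        "--do_train", "--do_eval",
    ],
    check=True,
)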
@@ -17,17 +17,16 @@
 import logging
 import os
 import sys
-import warnings
 from dataclasses import dataclass, field
 from random import randint
-from typing import Optional
+from typing import List, Optional, Union

 import datasets
 import evaluate
 import numpy as np
-from datasets import DatasetDict, load_dataset

 import transformers
+from datasets import Dataset, DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
+from tqdm import tqdm
 from transformers import (
     AutoConfig,
     AutoFeatureExtractor,
@@ -39,7 +38,6 @@ from transformers import (
 )
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version, send_example_telemetry
-from transformers.utils.versions import require_version


 logger = logging.getLogger(__name__)
@@ -66,36 +64,53 @@ class DataTrainingArguments:
     the command line.
     """

-    dataset_name: Optional[str] = field(default=None, metadata={"help": "Name of a dataset from the datasets package"})
-    dataset_config_name: Optional[str] = field(
-        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
-    )
-    train_file: Optional[str] = field(
-        default=None, metadata={"help": "A file containing the training audio paths and labels."}
-    )
-    eval_file: Optional[str] = field(
-        default=None, metadata={"help": "A file containing the validation audio paths and labels."}
-    )
-    train_split_name: str = field(
-        default="train",
-        metadata={
-            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
-        },
-    )
-    eval_split_name: str = field(
-        default="validation",
-        metadata={
-            "help": (
-                "The name of the training data set split to use (via the datasets library). Defaults to 'validation'"
-            )
-        },
-    )
+    train_dataset_name: str = field(
+        default=None,
+        metadata={
+            "help": "The name of the training dataset to use (via the datasets library). Load and combine "
+            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
+            "librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
+        },
+    )
+    train_dataset_config_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
+            "multiple datasets by separating dataset configs by a '+' symbol."
+        },
+    )
+    train_dataset_samples: str = field(
+        default=None,
+        metadata={
+            "help": "Number of samples in the training data. Load and combine "
+            "multiple datasets by separating dataset samples by a '+' symbol."
+        },
+    )
+    eval_dataset_name: str = field(
+        default=None,
+        metadata={
+            "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
+        },
+    )
+    eval_dataset_config_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified."
         },
     )
     audio_column_name: str = field(
         default="audio",
         metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
     )
-    label_column_name: str = field(
-        default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
+    train_label_column_name: str = field(
+        default="label",
+        metadata={
+            "help": "The name of the dataset column containing the labels in the train set. Defaults to 'label'"
+        },
+    )
+    eval_label_column_name: str = field(
+        default="label",
+        metadata={"help": "The name of the dataset column containing the labels in the eval set. Defaults to 'label'"},
     )
     max_train_samples: Optional[int] = field(
         default=None,
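As the help strings above describe, every '+'-separated field is split and matched positionally, so the i-th config, split, and sample count belong to the i-th dataset. A small standalone sketch of that convention (dataset ids and counts are illustrative, not taken from the commit):

# Illustration of the '+'-separated convention used by the new arguments:
# each field is split on '+' and the entries are paired up by position.
train_dataset_name = "librispeech_asr+common_voice"  # hypothetical dataset ids
train_dataset_config_name = "clean+en"               # one config per dataset
train_dataset_samples = "100+360"                    # relative number of samples per dataset

names = train_dataset_name.split("+")
configs = train_dataset_config_name.split("+")
samples = [float(s) for s in train_dataset_samples.split("+")]

for name, config, n in zip(names, configs, samples):
    print(f"dataset={name!r} config={config!r} samples={n}")
# dataset='librispeech_asr' config='clean' samples=100.0
# dataset='common_voice' config='en' samples=360.0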
@@ -159,12 +174,6 @@ class ModelArguments:
             )
         },
     )
-    use_auth_token: bool = field(
-        default=None,
-        metadata={
-            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
-        },
-    )
     trust_remote_code: bool = field(
         default=False,
         metadata={
@@ -175,29 +184,153 @@ class ModelArguments:
             )
         },
     )
-    freeze_feature_extractor: Optional[bool] = field(
-        default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
-    )
     ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
     )

-    def __post_init__(self):
-        if not self.freeze_feature_extractor and self.freeze_feature_encoder:
-            warnings.warn(
-                "The argument `--freeze_feature_extractor` is deprecated and "
-                "will be removed in a future version. Use `--freeze_feature_encoder` "
-                "instead. Setting `freeze_feature_encoder==True`.",
-                FutureWarning,
-            )
-        if self.freeze_feature_extractor and not self.freeze_feature_encoder:
-            raise ValueError(
-                "The argument `--freeze_feature_extractor` is deprecated and "
-                "should not be used in combination with `--freeze_feature_encoder`. "
-                "Only make use of `--freeze_feature_encoder`."
-            )

+
+def convert_dataset_str_to_list(
+    dataset_names,
+    dataset_config_names,
+    splits=None,
+    label_column_names=None,
+    dataset_samples=None,
+    default_split="train",
+):
+    if isinstance(dataset_names, str):
+        dataset_names = dataset_names.split("+")
+
+        # we assume that all the datasets we're using derive from the distil-whisper org on the Hub - prepend the org name if necessary
+        for i in range(len(dataset_names)):
+            ds_name = dataset_names[i]
+            dataset_names[i] = f"distil-whisper/{ds_name}" if "/" not in ds_name else ds_name
+
+        dataset_config_names = dataset_config_names.split("+")
+        splits = splits.split("+") if splits is not None else None
+        label_column_names = label_column_names.split("+") if label_column_names is not None else None
+        dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None
+
+    # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
+    if len(dataset_names) != len(dataset_config_names):
+        raise ValueError(
+            f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
+            f" {len(dataset_config_names)} configs."
+        )
+
+    if splits is not None and len(splits) != len(dataset_names):
+        raise ValueError(
+            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
+        )
+
+    if label_column_names is not None and len(label_column_names) != len(dataset_names):
+        raise ValueError(
+            f"Ensure one label column name is passed for each dataset, got {len(dataset_names)} datasets and"
+            f" {len(label_column_names)} label column names."
+        )
+
+    if dataset_samples is not None:
+        if len(dataset_samples) != len(dataset_names):
+            raise ValueError(
+                f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
+                f"{len(dataset_samples)} samples."
+            )
+        dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
+    else:
+        dataset_samples = [None] * len(dataset_names)
+
+    label_column_names = (
+        label_column_names if label_column_names is not None else ["label" for _ in range(len(dataset_names))]
+    )
+    splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
+
+    dataset_names_dict = []
+    for i, ds_name in enumerate(dataset_names):
+        dataset_names_dict.append(
+            {
+                "name": ds_name,
+                "config": dataset_config_names[i],
+                "split": splits[i],
+                "label_column_name": label_column_names[i],
+                "samples": dataset_samples[i],
+            }
+        )
+    return dataset_names_dict
+
+
+def load_multiple_datasets(
+    dataset_names: Union[List, str],
+    dataset_config_names: Union[List, str],
+    splits: Optional[Union[List, str]] = None,
+    label_column_names: Optional[List] = None,
+    stopping_strategy: Optional[str] = "first_exhausted",
+    dataset_samples: Optional[Union[List, np.array]] = None,
+    streaming: Optional[bool] = True,
+    seed: Optional[int] = None,
+    audio_column_name: Optional[str] = "audio",
+    **kwargs,
+) -> Union[Dataset, IterableDataset]:
+    dataset_names_dict = convert_dataset_str_to_list(
+        dataset_names, dataset_config_names, splits, label_column_names, dataset_samples
+    )
+
+    if dataset_samples is not None:
+        dataset_samples = [ds_dict["samples"] for ds_dict in dataset_names_dict]
+        probabilities = np.array(dataset_samples) / np.sum(dataset_samples)
+    else:
+        probabilities = None
+
+    all_datasets = []
+    # iterate over the datasets we want to interleave
+    for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
+        dataset = load_dataset(
+            dataset_dict["name"],
+            dataset_dict["config"],
+            split=dataset_dict["split"],
+            streaming=streaming,
+            **kwargs,
+        )
+
+        dataset_features = dataset.features.keys()
+        if audio_column_name not in dataset_features:
+            raise ValueError(
+                f"Audio column name '{audio_column_name}' not found in dataset"
+                f" '{dataset_dict['name']}'. Make sure to set `--audio_column_name` to"
+                f" the correct audio column - one of {', '.join(dataset_features)}."
+            )
+
+        if dataset_dict["label_column_name"] not in dataset_features:
+            raise ValueError(
+                f"Label column name {dataset_dict['label_column_name']} not found in dataset"
+                f" '{dataset_dict['name']}'. Make sure to set `--label_column_name` to the"
+                f" correct label column - one of {', '.join(dataset_features)}."
+            )
+
+        # blanket renaming of all label columns to label
+        if dataset_dict["label_column_name"] != "label":
+            dataset = dataset.rename_column(dataset_dict["label_column_name"], "label")
+
+        dataset_features = dataset.features.keys()
+        columns_to_keep = {"audio", "label"}
+        dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
+        all_datasets.append(dataset)
+
+    if len(all_datasets) == 1:
+        # we have a single dataset so just return it as is
+        return all_datasets[0]
+
+    if streaming:
+        interleaved_dataset = interleave_datasets(
+            all_datasets,
+            stopping_strategy=stopping_strategy,
+            probabilities=probabilities,
+            seed=seed,
+        )
+    else:
+        interleaved_dataset = concatenate_datasets(all_datasets)
+
+    return interleaved_dataset
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
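To make the two new helpers concrete, here is a hedged sketch of what `convert_dataset_str_to_list` returns for a two-dataset spec and how `load_multiple_datasets` turns the sample counts into interleaving probabilities. The dataset ids and counts are illustrative, and the snippet re-implements the arithmetic rather than importing the script:

import numpy as np

# Roughly what convert_dataset_str_to_list("esc50+speech_commands", "default+v0.02",
#                                           dataset_samples="2000+8000") would produce:
# names without a '/' get the distil-whisper org prepended, splits default to "train",
# and label columns default to "label".
dataset_names_dict = [
    {"name": "distil-whisper/esc50", "config": "default", "split": "train",
     "label_column_name": "label", "samples": 2000.0},
    {"name": "distil-whisper/speech_commands", "config": "v0.02", "split": "train",
     "label_column_name": "label", "samples": 8000.0},
]

# load_multiple_datasets then derives interleaving probabilities from the sample counts:
samples = [d["samples"] for d in dataset_names_dict]
probabilities = np.array(samples) / np.sum(samples)
print(probabilities)  # [0.2 0.8] -> 20% of interleaved examples from the first dataset, 80% from the second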
@@ -212,15 +345,6 @@
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-    if model_args.use_auth_token is not None:
-        warnings.warn(
-            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
-            FutureWarning,
-        )
-        if model_args.token is not None:
-            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
-        model_args.token = model_args.use_auth_token
-
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
     send_example_telemetry("run_audio_classification", model_args, data_args)
@@ -269,31 +393,80 @@
     # Initialize our dataset and prepare it for the audio classification task.
     raw_datasets = DatasetDict()
-    raw_datasets["train"] = load_dataset(
-        data_args.dataset_name,
-        data_args.dataset_config_name,
-        split=data_args.train_split_name,
-        token=model_args.token,
-    )
-    raw_datasets["eval"] = load_dataset(
-        data_args.dataset_name,
-        data_args.dataset_config_name,
-        split=data_args.eval_split_name,
-        token=model_args.token,
-    )
-
-    if data_args.audio_column_name not in raw_datasets["train"].column_names:
-        raise ValueError(
-            f"--audio_column_name {data_args.audio_column_name} not found in dataset '{data_args.dataset_name}'. "
-            "Make sure to set `--audio_column_name` to the correct audio column - one of "
-            f"{', '.join(raw_datasets['train'].column_names)}."
-        )
-
-    if data_args.label_column_name not in raw_datasets["train"].column_names:
-        raise ValueError(
-            f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
-            "Make sure to set `--label_column_name` to the correct text column - one of "
-            f"{', '.join(raw_datasets['train'].column_names)}."
-        )
+
+    # set seed for determinism
+    set_seed(training_args.seed)
+
+    if training_args.do_train:
+        raw_datasets["train"] = load_multiple_datasets(
+            data_args.train_dataset_name,
+            data_args.train_dataset_config_name,
+            splits=data_args.train_split_name,
+            label_column_names=data_args.train_label_column_name,
+            streaming=data_args.streaming,
+            dataset_samples=data_args.train_dataset_samples,
+            seed=training_args.seed,
+            cache_dir=data_args.dataset_cache_dir,
+            token=True if model_args.token else None,
+            trust_remote_code=data_args.trust_remote_code,
+        )
+        raw_datasets_train_features = raw_datasets["train"].features.keys()
+
+    if training_args.do_eval:
+        dataset_names_dict = convert_dataset_str_to_list(
+            data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
+            data_args.eval_dataset_config_name
+            if data_args.eval_dataset_config_name
+            else data_args.train_dataset_config_name,
+            splits=data_args.eval_split_name,
+            label_column_names=data_args.eval_label_column_name,
+        )
+        all_eval_splits = []
+        if len(dataset_names_dict) == 1:
+            # load a single eval set
+            dataset_dict = dataset_names_dict[0]
+            all_eval_splits.append("eval")
+            raw_datasets["eval"] = load_dataset(
+                dataset_dict["name"],
+                dataset_dict["config"],
+                split=dataset_dict["split"],
+                cache_dir=data_args.dataset_cache_dir,
+                token=True if model_args.token else None,
+                streaming=data_args.streaming,
+                trust_remote_code=data_args.trust_remote_code,
+            )
+        else:
+            # load multiple eval sets
+            for dataset_dict in dataset_names_dict:
+                pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
+                all_eval_splits.append(pretty_name)
+                raw_datasets[pretty_name] = load_dataset(
+                    dataset_dict["name"],
+                    dataset_dict["config"],
+                    split=dataset_dict["split"],
+                    cache_dir=data_args.dataset_cache_dir,
+                    token=True if model_args.token else None,
+                    streaming=data_args.streaming,
+                    trust_remote_code=data_args.trust_remote_code,
+                )
+                features = raw_datasets[pretty_name].features.keys()
+                if dataset_dict["label_column_name"] not in features:
+                    raise ValueError(
+                        f"--label_column_name {dataset_dict['label_column_name']} not found in dataset"
+                        f" '{dataset_dict['name']}'. Make sure to set `--label_column_name` to the correct"
+                        f" label column - one of {', '.join(features)}."
+                    )
+                elif dataset_dict["label_column_name"] != "label":
+                    raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
+                        dataset_dict["label_column_name"], "label"
+                    )
+                raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
+                    set(raw_datasets[pretty_name].features.keys()) - {"audio", "label"}
+                )
+
+    if not training_args.do_train and not training_args.do_eval:
+        raise ValueError(
+            "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
+        )

     # Setting `return_attention_mask=True` is the way to get a correctly masked mean-pooling over
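When more than one evaluation dataset is given, each one is stored in `raw_datasets` under a pretty name derived from the dataset id and split, as the loop above shows. A small sketch of that naming (the ids and splits here are hypothetical):

# Reproduces the pretty-name construction from the eval loop above for two hypothetical eval sets.
eval_sets = [
    {"name": "distil-whisper/esc50", "split": "test"},
    {"name": "distil-whisper/speech_commands", "split": "validation.clean"},
]

all_eval_splits = []
for dataset_dict in eval_sets:
    # dataset id after the org prefix + split, with '.' replaced so the key stays a clean identifier
    pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
    all_eval_splits.append(pretty_name)

print(all_eval_splits)  # ['esc50/test', 'speech_commands/validation-clean']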
@@ -325,8 +498,6 @@
             subsampled_wavs.append(wav)
         inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
         output_batch = {model_input_name: inputs.get(model_input_name)}
-        output_batch["labels"] = list(batch[data_args.label_column_name])
-
         return output_batch

     def val_transforms(batch):
@@ -334,8 +505,6 @@
         wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
         inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
         output_batch = {model_input_name: inputs.get(model_input_name)}
-        output_batch["labels"] = list(batch[data_args.label_column_name])
-
         return output_batch

     # Prepare label mappings.
@@ -388,7 +557,9 @@ def main():
             raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
         )
         # Set the training transforms
-        raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
+        raw_datasets["train"].set_transform(
+            train_transforms, columns=[model_input_name, "labels"], output_all_columns=False
+        )

     if training_args.do_eval:
         if data_args.max_eval_samples is not None:
@@ -396,7 +567,9 @@
             raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
         )
         # Set the validation transforms
-        raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)
+        raw_datasets["eval"].set_transform(
+            val_transforms, columns=[model_input_name, "labels"], output_all_columns=False
+        )

     # Initialize our trainer
     trainer = Trainer(
......