Commit 94f40c57 authored by sanchit-gandhi

working audio class

parent 7cbf4d55
@@ -136,13 +136,13 @@ class DataTrainingArguments:
         metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
     )
     train_label_column_name: str = field(
-        default="label",
+        default="labels",
         metadata={
-            "help": "The name of the dataset column containing the labels in the train set. Defaults to 'label'"
+            "help": "The name of the dataset column containing the labels in the train set. Defaults to 'labels'"
         },
     )
     eval_label_column_name: str = field(
-        default="label",
-        metadata={"help": "The name of the dataset column containing the labels in the eval set. Defaults to 'label'"},
+        default="labels",
+        metadata={"help": "The name of the dataset column containing the labels in the eval set. Defaults to 'labels'"},
     )
     max_train_samples: Optional[int] = field(
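For context, these dataclass fields are turned into CLI flags by transformers' HfArgumentParser, so the new default means --train_label_column_name falls back to "labels" when the flag is omitted. A minimal sketch of that mechanism (ToyArguments is a hypothetical stand-in for the script's DataTrainingArguments):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ToyArguments:
    # Hypothetical stand-in for DataTrainingArguments above.
    train_label_column_name: str = field(
        default="labels",
        metadata={"help": "The name of the dataset column containing the labels in the train set."},
    )


parser = HfArgumentParser(ToyArguments)
(args,) = parser.parse_args_into_dataclasses(args=[])  # no flags passed -> field defaults apply
print(args.train_label_column_name)  # labels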
@@ -275,7 +275,7 @@ def convert_dataset_str_to_list(
     dataset_samples = [None] * len(dataset_names)
     label_column_names = (
-        label_column_names if label_column_names is not None else ["label" for _ in range(len(dataset_names))]
+        label_column_names if label_column_names is not None else ["labels" for _ in range(len(dataset_names))]
     )
     splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
@@ -300,7 +300,7 @@ def load_multiple_datasets(
     label_column_names: Optional[List] = None,
     stopping_strategy: Optional[str] = "first_exhausted",
     dataset_samples: Optional[Union[List, np.array]] = None,
-    streaming: Optional[bool] = True,
+    streaming: Optional[bool] = False,
     seed: Optional[int] = None,
     audio_column_name: Optional[str] = "audio",
     **kwargs,
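The stopping_strategy parameter above matches the argument of the same name on datasets' interleave_datasets, which this helper presumably forwards when mixing the loaded sets. A minimal sketch of its behaviour, assuming the datasets library and toy data:

from datasets import Dataset, interleave_datasets

ds_a = Dataset.from_dict({"labels": ["en"] * 4})
ds_b = Dataset.from_dict({"labels": ["fr"] * 2})

# "first_exhausted" stops as soon as the smallest dataset runs out of examples ...
mixed = interleave_datasets([ds_a, ds_b], stopping_strategy="first_exhausted")
print(len(mixed))  # 4: two examples drawn from each source

# ... while "all_exhausted" oversamples the smaller datasets until every example is seen.
mixed_all = interleave_datasets([ds_a, ds_b], stopping_strategy="all_exhausted")
print(len(mixed_all))  # 8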
@@ -342,11 +342,11 @@ def load_multiple_datasets(
         )
 
-        # blanket renaming of all label columns to label
-        if dataset_dict["label_column_name"] != "label":
-            dataset = dataset.rename_column(dataset_dict["label_column_name"], "label")
+        # blanket renaming of all label columns to "labels"
+        if dataset_dict["label_column_name"] != "labels":
+            dataset = dataset.rename_column(dataset_dict["label_column_name"], "labels")
 
         dataset_features = dataset.features.keys()
-        columns_to_keep = {"audio", "label"}
+        columns_to_keep = {"audio", "labels"}
         dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
         all_datasets.append(dataset)
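The rename-then-prune pattern above normalises heterogeneous source datasets to a common ("audio", "labels") schema before they are mixed. A self-contained sketch with a toy dataset ("language" stands in for whatever label column a source dataset happens to use):

from datasets import Dataset

dataset = Dataset.from_dict(
    {"audio": [[0.0, 0.1]], "language": ["en"], "speaker_id": [7]}
)
label_column_name = "language"
if label_column_name != "labels":
    dataset = dataset.rename_column(label_column_name, "labels")
columns_to_keep = {"audio", "labels"}
dataset = dataset.remove_columns(set(dataset.features.keys()) - columns_to_keep)
print(dataset.column_names)  # ['audio', 'labels']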
@@ -451,30 +451,20 @@ def main():
             label_column_names=data_args.eval_label_column_name,
         )
         all_eval_splits = []
-        if len(dataset_names_dict) == 1:
-            # load a single eval set
-            dataset_dict = dataset_names_dict[0]
-            all_eval_splits.append("eval")
-            raw_datasets["eval"] = load_dataset(
-                dataset_dict["name"],
-                dataset_dict["config"],
-                split=dataset_dict["split"],
-                cache_dir=model_args.cache_dir,
-                token=True if model_args.token else None,
-                trust_remote_code=model_args.trust_remote_code,
-                # streaming=data_args.streaming,
-            )
-        else:
-            # load multiple eval sets
-            for dataset_dict in dataset_names_dict:
-                pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
-                all_eval_splits.append(pretty_name)
-                raw_datasets[pretty_name] = load_dataset(
-                    dataset_dict["name"],
-                    dataset_dict["config"],
-                    split=dataset_dict["split"],
-                    cache_dir=model_args.cache_dir,
-                    token=True if model_args.use_auth_token else None,
-                    trust_remote_code=model_args.trust_remote_code,
-                    # streaming=data_args.streaming,
-                )
+        # load multiple eval sets
+        for dataset_dict in dataset_names_dict:
+            pretty_name = (
+                f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
+                if len(dataset_names_dict) > 1
+                else "eval"
+            )
+            all_eval_splits.append(pretty_name)
+            raw_datasets[pretty_name] = load_dataset(
+                dataset_dict["name"],
+                dataset_dict["config"],
+                split=dataset_dict["split"],
+                cache_dir=model_args.cache_dir,
+                token=True if model_args.token else None,
+                trust_remote_code=model_args.trust_remote_code,
+                # streaming=data_args.streaming,
+            )
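The rewritten pretty_name expression collapses the old single/multiple code paths into one loop: a lone eval set keeps the plain "eval" key, while multiple sets get distinguishable "<dataset>/<split>" keys. A toy illustration (the dataset names are made up):

dataset_names_dict = [
    {"name": "mozilla-foundation/common_voice_13_0", "split": "test.fr"},
    {"name": "google/fleurs", "split": "test"},
]
for dataset_dict in dataset_names_dict:
    pretty_name = (
        f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
        if len(dataset_names_dict) > 1
        else "eval"
    )
    print(pretty_name)  # common_voice_13_0/test-fr, then fleurs/test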
@@ -485,12 +475,12 @@ def main():
                     "Make sure to set `--label_column_name` to the correct text column - one of "
                     f"{', '.join(raw_datasets['train'].column_names)}."
                 )
-            elif dataset_dict["label_column_name"] != "label":
+            elif dataset_dict["label_column_name"] != "labels":
                 raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
-                    dataset_dict["label_column_name"], "label"
+                    dataset_dict["label_column_name"], "labels"
                 )
             raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
-                set(raw_datasets[pretty_name].features.keys()) - {"audio", "label"}
+                set(raw_datasets[pretty_name].features.keys()) - {"audio", "labels"}
             )
 
     if not training_args.do_train and not training_args.do_eval:
@@ -529,56 +519,70 @@ def main():
     sampling_rate = feature_extractor.sampling_rate
     model_input_name = feature_extractor.model_input_names[0]
 
-    max_input_length = data_args.max_length_seconds * sampling_rate
-    min_input_length = data_args.min_length_seconds * sampling_rate
-
-    def prepare_dataset(sample):
-        audio = sample["audio"]["array"]
-        if len(audio) / sampling_rate > max_input_length:
-            audio = random_subsample(audio, max_input_length, sampling_rate)
-        inputs = feature_extractor(audio, sampling_rate=sampling_rate)
-        sample[model_input_name] = inputs.get(model_input_name)
-        sample["input_length"] = len(audio) / sampling_rate
-        sample["labels"] = preprocess_labels(sample["labels"])
-        return sample
-
-    vectorized_datasets = raw_datasets.map(
-        prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preprocess dataset"
-    )
-
-    # filter training data with inputs longer than max_input_length
-    def is_audio_in_length_range(length):
-        return min_input_length < length < max_input_length
-
-    vectorized_datasets = vectorized_datasets.filter(
-        is_audio_in_length_range,
-        input_columns=["input_length"],
-        num_proc=data_args.preprocessing_num_workers,
-        desc="Filtering by audio length",
-    )
-    # filter training data with non valid labels
+    # filter training data with invalid labels
     def is_label_valid(label):
         return label != "Unknown"
 
-    vectorized_datasets = vectorized_datasets.filter(
+    raw_datasets = raw_datasets.filter(
         is_label_valid,
         input_columns=["labels"],
         num_proc=data_args.preprocessing_num_workers,
         desc="Filtering by labels",
     )
-    # Prepare label mappings.
+    # Prepare label mappings
+    raw_datasets = raw_datasets.map(
+        lambda label: {"labels": preprocess_labels(label)},
+        input_columns=["labels"],
+        num_proc=data_args.preprocessing_num_workers,
+        desc="Pre-processing labels",
+    )
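Both the filter and the map above rely on input_columns, which hands the callable just the named column's values instead of full example dicts; a dict returned from .map() is written back onto the example. A minimal sketch, assuming the datasets library (normalise is a hypothetical stand-in for the script's preprocess_labels helper):

from datasets import Dataset

dataset = Dataset.from_dict({"audio": [[0.1], [0.2], [0.3]], "labels": ["EN", "Unknown", "FR"]})


def normalise(label: str) -> str:
    return label.lower()


dataset = dataset.filter(lambda label: label != "Unknown", input_columns=["labels"])
dataset = dataset.map(lambda label: {"labels": normalise(label)}, input_columns=["labels"])
print(dataset["labels"])  # ['en', 'fr']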
     # We'll include these in the model's config to get human readable labels in the Inference API.
-    labels = vectorized_datasets["train"]["labels"]
-    label2id, id2label, num_label = {}, {}, {}
-    for i, label in enumerate(labels):
-        num_label[label] += 1
-        if label not in label2id:
-            label2id[label] = str(i)
-            id2label[str(i)] = label
-    logger.info(f"Number of labels: {num_label}")
+    set_labels = set(raw_datasets["train"]["labels"]).union(set(raw_datasets["eval"]["labels"]))
+    label2id, id2label = {}, {}
+    for i, label in enumerate(sorted(set_labels)):
+        label2id[label] = str(i)
+        id2label[str(i)] = label
+
+    train_labels = raw_datasets["train"]["labels"]
+    num_labels = {key: 0 for key in set(train_labels)}
+    for label in train_labels:
+        num_labels[label] += 1
+
+    # Print a summary of the labels to stdout (helps identify low-count classes that could be filtered)
+    num_labels = sorted(num_labels.items(), key=lambda x: (-x[1], x[0]))
+    logger.info(f"{'Language':<15} {'Count':<5}")
+    logger.info("-" * 20)
+    for language, count in num_labels:
+        logger.info(f"{language:<15} {count:<5}")
+    def train_transforms(batch):
+        """Apply train_transforms across a batch."""
+        subsampled_wavs = []
+        for audio in batch["audio"]:
+            wav = random_subsample(audio["array"], max_length=data_args.max_length_seconds, sample_rate=sampling_rate)
+            subsampled_wavs.append(wav)
+        inputs = feature_extractor(subsampled_wavs, sampling_rate=sampling_rate)
+        output_batch = {model_input_name: inputs.get(model_input_name)}
+        output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
+        return output_batch
+
+    def val_transforms(batch):
+        """Apply val_transforms across a batch."""
+        wavs = [audio["array"] for audio in batch["audio"]]
+        inputs = feature_extractor(wavs, sampling_rate=sampling_rate)
+        output_batch = {model_input_name: inputs.get(model_input_name)}
+        output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
+        return output_batch
+
+    if training_args.do_train:
+        # Set the training transforms
+        raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
+
+    if training_args.do_eval:
+        # Set the validation transforms
+        raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)
 
-    # Load the accuracy metric from the datasets package
+    # Load the accuracy metric from the evaluate package
     metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
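The switch from the eager .map() (removed above) to set_transform is the core of this commit: the transform now runs lazily on every __getitem__, so random_subsample draws a fresh crop of each clip every epoch and nothing is serialised to the dataset cache. A minimal sketch of the semantics, with a toy transform:

from datasets import Dataset

dataset = Dataset.from_dict({"labels": ["en", "fr"]})


def transform(batch):
    # Called on every access, not once up front.
    return {"labels": [label.upper() for label in batch["labels"]]}


dataset.set_transform(transform, output_all_columns=False)
print(dataset[0])  # {'labels': 'EN'} - computed on the fly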
@@ -592,7 +596,7 @@ def main():
     config = AutoConfig.from_pretrained(
         model_args.config_name or model_args.model_name_or_path,
-        num_labels=len(labels),
+        num_labels=len(label2id),
         label2id=label2id,
         id2label=id2label,
         finetuning_task="audio-classification",
@@ -620,8 +624,8 @@ def main():
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
-        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
+        train_dataset=raw_datasets["train"] if training_args.do_train else None,
+        eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
         compute_metrics=compute_metrics,
         tokenizer=feature_extractor,
     )
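compute_metrics is untouched by this commit and its body is not shown; in this family of scripts it typically reduces to an argmax over the logits followed by the accuracy metric loaded above. A hedged sketch (the exact body may differ):

import evaluate
import numpy as np

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    # eval_pred.predictions: (batch, num_labels) logits; eval_pred.label_ids: int class ids
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)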
...@@ -649,7 +653,7 @@ def main(): ...@@ -649,7 +653,7 @@ def main():
kwargs = { kwargs = {
"finetuned_from": model_args.model_name_or_path, "finetuned_from": model_args.model_name_or_path,
"tasks": "audio-classification", "tasks": "audio-classification",
"dataset": data_args.dataset_name, "dataset": data_args.train_dataset_name.split("+")[0],
"tags": ["audio-classification"], "tags": ["audio-classification"],
} }
if training_args.push_to_hub: if training_args.push_to_hub:
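The model-card change assumes train_dataset_name is a "+"-separated list when training on several datasets, so only the first entry is recorded in the Hub metadata. A one-line illustration (the names are made up):

train_dataset_name = "mozilla-foundation/common_voice_13_0+google/fleurs"
print(train_dataset_name.split("+")[0])  # mozilla-foundation/common_voice_13_0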