Commit 94f40c57 authored by sanchit-gandhi

working audio class

parent 7cbf4d55
@@ -136,13 +136,13 @@ class DataTrainingArguments:
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
)
train_label_column_name: str = field(
default="label",
default="labels",
metadata={
"help": "The name of the dataset column containing the labels in the train set. Defaults to 'label'"
},
)
eval_label_column_name: str = field(
default="label",
default="labels",
metadata={"help": "The name of the dataset column containing the labels in the eval set. Defaults to 'label'"},
)
max_train_samples: Optional[int] = field(
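[Editor's note] The `label` → `labels` switch lines the column name up with the keyword that `transformers` models accept in `forward` for loss computation, rather than relying on the default collator's implicit `label` → `labels` rename. A minimal sketch of the normalisation (dataset contents and column names here are hypothetical):

```python
from datasets import Dataset

# Hypothetical toy dataset whose label column has a different name.
ds = Dataset.from_dict({"audio_path": ["a.wav", "b.wav"], "language": ["en", "hi"]})

# Normalise to "labels", the keyword the model's forward() consumes for the loss.
if "labels" not in ds.column_names:
    ds = ds.rename_column("language", "labels")

print(ds.column_names)  # ['audio_path', 'labels']
```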
@@ -275,7 +275,7 @@ def convert_dataset_str_to_list(
dataset_samples = [None] * len(dataset_names)
label_column_names = (
-label_column_names if label_column_names is not None else ["label" for _ in range(len(dataset_names))]
+label_column_names if label_column_names is not None else ["labels" for _ in range(len(dataset_names))]
)
splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]
@@ -300,7 +300,7 @@ def load_multiple_datasets(
label_column_names: Optional[List] = None,
stopping_strategy: Optional[str] = "first_exhausted",
dataset_samples: Optional[Union[List, np.array]] = None,
-streaming: Optional[bool] = True,
+streaming: Optional[bool] = False,
seed: Optional[int] = None,
audio_column_name: Optional[str] = "audio",
**kwargs,
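[Editor's note] Flipping the `streaming` default to `False` makes `load_dataset` return a regular map-style `Dataset` (downloaded and cached) rather than a lazy `IterableDataset`; the `.filter`/`.map`/`set_transform` calls further down assume the map-style API. A rough illustration, using an arbitrary public audio dataset as an example:

```python
from datasets import load_dataset

# streaming=False (the new default here): the split is downloaded and cached,
# and supports random access, num_proc-parallel map/filter, and set_transform.
ds = load_dataset("PolyAI/minds14", "en-US", split="train", streaming=False)

# streaming=True: examples are yielded lazily over the network; there is no
# random access and no multiprocess map, so the eager pipeline below would not run.
ds_stream = load_dataset("PolyAI/minds14", "en-US", split="train", streaming=True)
```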
@@ -342,11 +342,11 @@
)
# blanket renaming of all label columns to "labels"
-if dataset_dict["label_column_name"] != "label":
-dataset = dataset.rename_column(dataset_dict["label_column_name"], "label")
+if dataset_dict["label_column_name"] != "labels":
+dataset = dataset.rename_column(dataset_dict["label_column_name"], "labels")
dataset_features = dataset.features.keys()
-columns_to_keep = {"audio", "label"}
+columns_to_keep = {"audio", "labels"}
dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
all_datasets.append(dataset)
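[Editor's note] The rename-then-prune pattern above is what makes mixing datasets possible: `interleave_datasets` requires every input to share an identical feature schema, so each source is reduced to exactly `{"audio", "labels"}` first. Condensed into a helper (a sketch, not part of the commit):

```python
from datasets import Dataset

def normalize_columns(dataset: Dataset, label_column_name: str) -> Dataset:
    # Blanket-rename the source's label column to "labels".
    if label_column_name != "labels":
        dataset = dataset.rename_column(label_column_name, "labels")
    # Drop everything except the two columns all datasets must share.
    columns_to_keep = {"audio", "labels"}
    return dataset.remove_columns(list(set(dataset.features.keys()) - columns_to_keep))
```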
@@ -451,11 +451,15 @@ def main():
label_column_names=data_args.eval_label_column_name,
)
all_eval_splits = []
-if len(dataset_names_dict) == 1:
-# load a single eval set
-dataset_dict = dataset_names_dict[0]
-all_eval_splits.append("eval")
-raw_datasets["eval"] = load_dataset(
+# load one or more eval sets
+for dataset_dict in dataset_names_dict:
+pretty_name = (
+f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
+if len(dataset_names_dict) > 1
+else "eval"
+)
+all_eval_splits.append(pretty_name)
+raw_datasets[pretty_name] = load_dataset(
dataset_dict["name"],
dataset_dict["config"],
split=dataset_dict["split"],
@@ -464,34 +468,20 @@ def main():
trust_remote_code=model_args.trust_remote_code,
# streaming=data_args.streaming,
)
-else:
-# load multiple eval sets
-for dataset_dict in dataset_names_dict:
-pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
-all_eval_splits.append(pretty_name)
-raw_datasets[pretty_name] = load_dataset(
-dataset_dict["name"],
-dataset_dict["config"],
-split=dataset_dict["split"],
-cache_dir=model_args.cache_dir,
-token=True if model_args.use_auth_token else None,
-trust_remote_code=model_args.trust_remote_code,
-# streaming=data_args.streaming,
features = raw_datasets[pretty_name].features.keys()
if dataset_dict["label_column_name"] not in features:
raise ValueError(
f"--label_column_name {data_args.eval_label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct label column - one of "
f"{', '.join(raw_datasets['train'].column_names)}."
)
elif dataset_dict["label_column_name"] != "label":
raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
dataset_dict["label_column_name"], "label"
)
raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
set(raw_datasets[pretty_name].features.keys()) - {"audio", "label"}
elif dataset_dict["label_column_name"] != "labels":
raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
dataset_dict["label_column_name"], "labels"
)
raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
set(raw_datasets[pretty_name].features.keys()) - {"audio", "labels"}
)
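[Editor's note] Collapsing the two branches into one loop is the main simplification here: the split key is `"eval"` when a single eval set is given, and a readable `<dataset>/<split>` key otherwise. Worked through by hand with hypothetical inputs:

```python
dataset_names_dict = [
    {"name": "google/fleurs", "split": "test"},
    {"name": "mozilla-foundation/common_voice_13_0", "split": "validation.hi"},
]

for dataset_dict in dataset_names_dict:
    pretty_name = (
        f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
        if len(dataset_names_dict) > 1
        else "eval"
    )
    print(pretty_name)
# fleurs/test
# common_voice_13_0/validation-hi
```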
if not training_args.do_train and not training_args.do_eval:
raise ValueError(
@@ -529,56 +519,70 @@ def main():
sampling_rate = feature_extractor.sampling_rate
model_input_name = feature_extractor.model_input_names[0]
-max_input_length = data_args.max_length_seconds * sampling_rate
-min_input_length = data_args.min_length_seconds * sampling_rate
-def prepare_dataset(sample):
-audio = sample["audio"]["array"]
-if len(audio) / sampling_rate > max_input_length:
-audio = random_subsample(audio, max_input_length, sampling_rate)
-inputs = feature_extractor(audio, sampling_rate=sampling_rate)
-sample[model_input_name] = inputs.get(model_input_name)
-sample["input_length"] = len(audio) / sampling_rate
-sample["labels"] = preprocess_labels(sample["labels"])
-return sample
-vectorized_datasets = raw_datasets.map(
-prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preprocess dataset"
-)
-# filter training data with inputs longer than max_input_length
-def is_audio_in_length_range(length):
-return min_input_length < length < max_input_length
-vectorized_datasets = vectorized_datasets.filter(
-is_audio_in_length_range,
-input_columns=["input_length"],
-num_proc=data_args.preprocessing_num_workers,
-desc="Filtering by audio length",
-)
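[Editor's note] The eager pipeline removed above filtered on a precomputed `input_length` column; with `input_columns`, `datasets` passes the bare column value to the predicate instead of the whole example. The general pattern, on a toy dataset:

```python
from datasets import Dataset

ds = Dataset.from_dict({"input_length": [0.4, 3.2, 31.0]})  # lengths in seconds

min_input_length, max_input_length = 1.0, 30.0  # illustrative bounds

def is_audio_in_length_range(length):
    return min_input_length < length < max_input_length

ds = ds.filter(is_audio_in_length_range, input_columns=["input_length"])
print(ds["input_length"])  # [3.2]
```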
-# filter training data with non valid labels
+# filter training data with non-valid labels
def is_label_valid(label):
return label != "Unknown"
-vectorized_datasets = vectorized_datasets.filter(
+raw_datasets = raw_datasets.filter(
is_label_valid,
input_columns=["labels"],
num_proc=data_args.preprocessing_num_workers,
desc="Filtering by labels",
)
-# Prepare label mappings.
+# Prepare label mappings
+raw_datasets = raw_datasets.map(
+lambda label: {"labels": preprocess_labels(label)},
+input_columns=["labels"],
+num_proc=data_args.preprocessing_num_workers,
+desc="Pre-processing labels",
+)
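[Editor's note] Because both calls set `input_columns=["labels"]`, the callable receives the bare label string, and the dict returned by the `map` lambda is merged back into each example. `preprocess_labels` is defined elsewhere in the script; a self-contained stand-in:

```python
from datasets import Dataset

ds = Dataset.from_dict({"labels": [" Hindi", "TAMIL ", "Unknown"]})

def preprocess_labels(label):
    # Hypothetical stand-in for the script's own label normalisation.
    return label.strip().lower()

ds = ds.filter(lambda label: label != "Unknown", input_columns=["labels"])
ds = ds.map(lambda label: {"labels": preprocess_labels(label)}, input_columns=["labels"])
print(ds["labels"])  # ['hindi', 'tamil']
```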
# We'll include these in the model's config to get human readable labels in the Inference API.
-labels = vectorized_datasets["train"]["labels"]
-label2id, id2label, num_label = {}, {}, {}
-for i, label in enumerate(labels):
-num_label[label] += 1
-if label not in label2id:
-label2id[label] = str(i)
-id2label[str(i)] = label
+set_labels = set(raw_datasets["train"]["labels"]).union(set(raw_datasets["eval"]["labels"]))
+label2id, id2label = {}, {}
+for i, label in enumerate(set(set_labels)):
+label2id[label] = str(i)
+id2label[str(i)] = label
+train_labels = raw_datasets["train"]["labels"]
+num_labels = {key: 0 for key in set(train_labels)}
+for label in train_labels:
+num_labels[label] += 1
+# Print a summary of the labels to stdout (helps identify low-count classes that could be filtered)
+num_labels = sorted(num_labels.items(), key=lambda x: (-x[1], x[0]))
+logger.info(f"{'Language':<15} {'Count':<5}")
+logger.info("-" * 20)
+for language, count in num_labels:
+logger.info(f"{language:<15} {count:<5}")
+def train_transforms(batch):
+"""Apply train_transforms across a batch."""
+subsampled_wavs = []
+for audio in batch["audio"]:
+wav = random_subsample(audio["array"], max_length=data_args.max_length_seconds, sample_rate=sampling_rate)
+subsampled_wavs.append(wav)
+inputs = feature_extractor(subsampled_wavs, sampling_rate=sampling_rate)
+output_batch = {model_input_name: inputs.get(model_input_name)}
+output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
+return output_batch
+def val_transforms(batch):
+"""Apply val_transforms across a batch."""
+wavs = [audio["array"] for audio in batch["audio"]]
+inputs = feature_extractor(wavs, sampling_rate=sampling_rate)
+output_batch = {model_input_name: inputs.get(model_input_name)}
+output_batch["labels"] = [int(label2id[label]) for label in batch["labels"]]
+return output_batch
logger.info(f"Number of labels: {num_label}")
+if training_args.do_train:
+# Set the training transforms
+raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
+if training_args.do_eval:
+# Set the validation transforms
+raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)
# Load the accuracy metric from the datasets package
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
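[Editor's note] The `compute_metrics` wired into the `Trainer` below falls outside this hunk; with the `evaluate` accuracy metric it conventionally looks like the following sketch (not the commit's exact code):

```python
import numpy as np

def compute_metrics(eval_pred):
    # eval_pred carries the model logits and the integer label ids.
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)
```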
@@ -592,7 +596,7 @@ def main():
config = AutoConfig.from_pretrained(
model_args.config_name or model_args.model_name_or_path,
-num_labels=len(labels),
+num_labels=len(label2id),
label2id=label2id,
id2label=id2label,
finetuning_task="audio-classification",
@@ -620,8 +624,8 @@ def main():
trainer = Trainer(
model=model,
args=training_args,
-train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
-eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
+train_dataset=raw_datasets["train"] if training_args.do_train else None,
+eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=feature_extractor,
)
@@ -649,7 +653,7 @@ def main():
kwargs = {
"finetuned_from": model_args.model_name_or_path,
"tasks": "audio-classification",
"dataset": data_args.dataset_name,
"dataset": data_args.train_dataset_name.split("+")[0],
"tags": ["audio-classification"],
}
if training_args.push_to_hub:
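[Editor's note] The model-card change assumes `train_dataset_name` packs several dataset ids into one `+`-separated string (the same convention `convert_dataset_str_to_list` parses) and credits the first one. For example, with an illustrative value:

```python
train_dataset_name = "google/fleurs+mozilla-foundation/common_voice_13_0"
print(train_dataset_name.split("+")[0])  # google/fleurs
```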