Commit 9d25447e authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

fix eval + correct json

parent 0f6d59d4
{ {
"model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/", "model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/",
"feature_extractor_name":"facebook/encodec_24khz", "feature_extractor_name":"facebook/encodec_32khz",
"description_tokenizer_name":"t5-base", "description_tokenizer_name":"t5-base",
"prompt_tokenizer_name":"t5-base", "prompt_tokenizer_name":"t5-base",
......
{ {
"model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/", "model_name_or_path": "/home/yoach/dataspeech/artefacts/tiny-model/",
"feature_extractor_name":"facebook/encodec_24khz", "feature_extractor_name":"facebook/encodec_32khz",
"description_tokenizer_name":"t5-base", "description_tokenizer_name":"t5-base",
"prompt_tokenizer_name":"t5-base", "prompt_tokenizer_name":"t5-base",
......
...@@ -31,7 +31,7 @@ import evaluate ...@@ -31,7 +31,7 @@ import evaluate
from tqdm import tqdm from tqdm import tqdm
from pathlib import Path from pathlib import Path
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union, Set
import datasets import datasets
import numpy as np import numpy as np
...@@ -606,7 +606,7 @@ def load_multiple_datasets( ...@@ -606,7 +606,7 @@ def load_multiple_datasets(
streaming: Optional[bool] = False, streaming: Optional[bool] = False,
seed: Optional[int] = None, seed: Optional[int] = None,
id_column_name: Optional[str] = None, id_column_name: Optional[str] = None,
columns_to_keep: Optional[set[str]] = None, columns_to_keep: Optional[Set[str]] = None,
**kwargs, **kwargs,
) -> Union[Dataset, IterableDataset]: ) -> Union[Dataset, IterableDataset]:
dataset_names_dict = convert_dataset_str_to_list( dataset_names_dict = convert_dataset_str_to_list(
...@@ -1396,9 +1396,8 @@ def main(): ...@@ -1396,9 +1396,8 @@ def main():
# Gather all predictions and targets # Gather all predictions and targets
# TODO: also add prompt ids # TODO: also add prompt ids
# TODO: better gather # TODO: better gather
generated_audios, input_ids, prompts = accelerator.gather_for_metrics( generated_audios, input_ids, prompts = accelerator.pad_across_processes((generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0)
(generated_audios, batch["input_ids"], batch["prompt_input_ids"]) generated_audios, input_ids, prompts = accelerator.gather_for_metrics((generated_audios, input_ids, prompts))
)
eval_preds.extend(generated_audios) eval_preds.extend(generated_audios)
eval_descriptions.extend(input_ids) eval_descriptions.extend(input_ids)
eval_prompts.extend(prompts) eval_prompts.extend(prompts)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment