Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
# Fields of the ModelArguments dataclass (the enclosing `class` header and its
# docstring lie above this span).
model_name_or_path: str = field(
    metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
    default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
feature_extractor_name: Optional[str] = field(
    default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
)
description_tokenizer_name: Optional[str] = field(
    default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
)
prompt_tokenizer_name: Optional[str] = field(
    default=None,
    metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
)
cache_dir: Optional[str] = field(
    default=None,
    metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
    default=True,
    metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
    default="main",
    metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
# Was annotated `int` with default=None; Optional[int] matches the actual default.
pad_token_id: Optional[int] = field(
    default=None,
    metadata={"help": "If specified, change the model pad token id."},
)
# Was annotated `int` with default=None; Optional[int] matches the actual default.
decoder_start_token_id: Optional[int] = field(
    default=None,
    metadata={"help": "If specified, change the model decoder start token id."},
)
freeze_text_encoder: bool = field(
    default=False,
    metadata={"help": "Whether to freeze the text encoder."},
)
do_sample: bool = field(
    default=True,
    metadata={"help": "Whether to do sampling or greedy decoding."},
)
temperature: float = field(
    default=1.0,
    metadata={"help": "Temperature if sampling."},
)
max_length: int = field(
    default=2580,
    metadata={"help": "Generation max length."},
)
# Default written as 6.0 (was the int literal 6) to agree with the float annotation.
bandwidth: float = field(
    default=6.0,
    metadata={"help": "Audio encoder bandwidth."},
)
asr_model_name_or_path: str = field(
    default="distil-whisper/distil-large-v2",
    metadata={
        "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
    },
)
clap_model_name_or_path: str = field(
    default="laion/larger_clap_music_and_speech",
    metadata={
        "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
    },
)
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    # Fields with default=None were mis-annotated `str`/`int` in the original;
    # they are Optional here so the annotation matches the default.
    train_dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    train_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset configs by a '+' symbol."
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={
            "help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
        },
    )
    train_dataset_samples: Optional[str] = field(
        default=None,
        metadata={
            "help": "Number of samples in the training data. Load and combine "
            "multiple datasets by separating dataset samples by a '+' symbol."
        },
    )
    train_metadata_dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    eval_dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset name if unspecified."
        },
    )
    eval_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to the training dataset config name if unspecified"
        },
    )
    eval_split_name: str = field(
        default="test",
        metadata={
            "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
        },
    )
    eval_metadata_dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    target_audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
    )
    description_column_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
    )
    prompt_column_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of validation examples to this "
                "value if set."
            )
        },
    )
    max_duration_in_seconds: float = field(
        default=35.0,
        metadata={
            "help": (
                "Filter audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`."
                "Also, used to set maximum audio length if `pad_to_max_length=True`."
            )
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
    )
    max_text_length: int = field(
        default=500, metadata={"help": "If set, max description lengths in number of characters."}
    )
    max_prompt_token_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "If set, filter samples with prompts that are longer than `max_prompt_token_length` tokens."
                "Also, used to set maximum prompt token length if `pad_to_max_length=True`."
            )
        },
    )
    max_description_token_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
                "Also, used to set maximum description token length if `pad_to_max_length=True`."
            )
        },
    )
    # NOTE(review): the original source was truncated in the middle of this
    # field's help string; the tail below is reconstructed to name the three
    # length caps defined above — confirm against the upstream file.
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "If `True`, pad audio, prompt and description to a maximum length set with respectively "
                "`max_duration_in_seconds`, `max_prompt_token_length` and `max_description_token_length`."
            )
        },
    )
# NOTE(review): removed an accidental verbatim duplicate paste of the
# ModelArguments fields and the DataTrainingArguments class defined above
# (it redefined every name, repeated the fused `classDataTrainingArguments:`
# syntax error, was itself truncated mid-string, and ended with two stray
# f-string fragments referencing `model_args` at module level, which would
# raise NameError on import).