    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # TODO: pretrain from scratch
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained feature extractor name or path if not the same as model_name"}
    )
    description_tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained description tokenizer name or path if not the same as model_name"}
    )
    prompt_tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained prompt tokenizer name or path if not the same as description_tokenizer_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    pad_token_id: Optional[int] = field(
        default=None,
        metadata={"help": "If specified, change the model pad token id."},
    )
    decoder_start_token_id: Optional[int] = field(
        default=None,
        metadata={"help": "If specified, change the model decoder start token id."},
    )
    freeze_text_encoder: bool = field(
        default=False,
        metadata={"help": "Whether to freeze the text encoder."},
    )
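
# A typical invocation of the model arguments above might look like the following.
# This is a sketch only; the script name and checkpoint are assumptions, not taken
# from this file:
#
#   python run_training.py \
#       --model_name_or_path facebook/musicgen-small \
#       --use_fast_tokenizer True \
#       --freeze_text_encoder True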
@dataclass
class DataSeq2SeqTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `HfArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""
    train_dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    train_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the training dataset to use (via the datasets library). Load and "
            "combine multiple datasets by separating dataset configs by a '+' symbol."
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'."
        },
    )
    train_dataset_samples: Optional[str] = field(
        default=None,
        metadata={
            "help": "Number of samples in each training dataset. Load and combine "
            "multiple datasets by separating dataset samples by a '+' symbol."
        },
    )
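
    # When combining datasets, the '+'-separated lists are assumed to be aligned
    # position by position across flags, e.g. (flag values are illustrative only):
    #
    #   --train_dataset_name "librispeech_asr+common_voice" \
    #   --train_dataset_config_name "clean+en" \
    #   --train_split_name "train.360+train" \
    #   --train_dataset_samples "100+200"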
    train_metadata_dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the metadata training dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `train_metadata_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    eval_dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the evaluation dataset to use (via the datasets library). Defaults to the training "
            "dataset name if unspecified."
        },
    )
    eval_dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the evaluation dataset to use (via the datasets library). Defaults to "
            "the training dataset config name if unspecified."
        },
    )
    eval_split_name: str = field(
        default="test",
        metadata={
            "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'."
        },
    )
    eval_metadata_dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the metadata evaluation dataset to use (via the datasets library). Load and combine "
            "multiple datasets by separating dataset ids by a '+' symbol. For example, to load and combine "
            "librispeech and common voice, set `eval_metadata_dataset_name='librispeech_asr+common_voice'`."
        },
    )
    target_audio_column_name: str = field(  # TODO
        default="audio",
        metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'."},
    )
    conditional_audio_column_name: Optional[str] = field(  # TODO
        default=None,
        metadata={"help": "The name of the dataset column containing the conditional audio data. Defaults to `None`."},
    )
    description_column_name: Optional[str] = field(  # TODO
        default=None,
        metadata={"help": "The name of the dataset column containing the description text data. Defaults to `None`."},
    )
    prompt_column_name: Optional[str] = field(  # TODO
        default=None,
        metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to `None`."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of validation examples to this "
                "value if set."
            )
        },
    )
    max_duration_in_seconds: float = field(
        default=35.0,
        metadata={"help": "Filter out audio files that are longer than `max_duration_in_seconds` seconds."},
    )
    min_duration_in_seconds: float = field(
        default=0.0,
        metadata={"help": "Filter out audio files that are shorter than `min_duration_in_seconds` seconds."},
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to only do data preprocessing and skip training. This is especially useful when data"
                " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
                " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
                " can consequently be loaded in distributed training."
            )
        },
    )
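
    # The intended two-step workflow (invocations are illustrative only; the script
    # name and process count are assumptions):
    #
    #   # 1. Preprocess on a single process so the cache is populated:
    #   python run_training.py ... --preprocessing_only True
    #   # 2. Re-launch distributed training against the now-cached datasets:
    #   torchrun --nproc_per_node 8 run_training.py ...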
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    use_auth_token: Optional[bool] = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    add_audio_samples_to_wandb: bool = field(
        default=False,
        metadata={
            "help": "If set and if `wandb` is in `args.report_to`, will add generated audio samples to wandb logs."
        },
    )
    id_column_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the example ids."},
    )
@dataclass
class DataCollatorMusicGenWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    """

logger.warning(
    f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer "
    f"('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
)
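
# A minimal sketch of the dynamic padding `DataCollatorMusicGenWithPadding` is
# expected to perform (field and column names below are assumptions, not the
# actual implementation, which is not shown in this excerpt):
#
#   def __call__(self, features):
#       # Pad tokenized prompts to the longest sequence in the batch.
#       prompts = [{"input_ids": f["prompt_input_ids"]} for f in features]
#       batch = self.prompt_tokenizer.pad(prompts, return_tensors="pt", padding="longest")
#       # Pad audio codes as labels with -100 so padding is ignored by the loss.
#       labels = torch.nn.utils.rnn.pad_sequence(
#           [torch.tensor(f["labels"]) for f in features], batch_first=True, padding_value=-100
#       )
#       batch["labels"] = labels
#       return batch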