"test/vscode:/vscode.git/clone" did not exist on "4260b3460f339292b9858db009111fc4965074a6"
Unverified commit c6b4674d authored by Yoach Lacombe, committed by GitHub

Merge branch 'sanchit-gandhi:main' into add-training

parents 3754ce00 c5de50bd
command:
  - python3
  - ${program}
  - --fp16
  - --fp16_full_eval
  - --do_train
  - --do_eval
  - --trust_remote_code
  - --overwrite_output_dir
  - --ignore_mismatched_sizes
  - --gradient_checkpointing
  - ${args}
method: random
metric:
  goal: maximize
  name: eval/accuracy
parameters:
  model_name_or_path:
    value: facebook/mms-lid-126
  train_dataset_name:
    value: stable-speech/concatenated-normalized-accent-dataset
  train_dataset_config_name:
    value: default
  train_split_name:
    value: train
  train_label_column_name:
    value: labels
  eval_dataset_name:
    value: stable-speech/concatenated-normalized-accent-dataset
  eval_dataset_config_name:
    value: default
  eval_split_name:
    value: test
  eval_label_column_name:
    value: labels
  output_dir:
    value: ./
  remove_unused_columns:
    value: false
  learning_rate:
    value: 1e-4
  lr_scheduler_type:
    value: constant_with_warmup
  max_length_seconds:
    value: 20
  min_length_seconds:
    value: 5
  attention_mask:
    value: true
  warmup_steps:
    value: 50
  max_steps:
    value: 1000
  per_device_train_batch_size:
    value: 32
  per_device_eval_batch_size:
    value: 32
  preprocessing_num_workers:
    value: 4
  dataloader_num_workers:
    value: 4
  logging_strategy:
    value: steps
  logging_steps:
    value: 10
  evaluation_strategy:
    value: steps
  eval_steps:
    value: 1000
  save_strategy:
    value: steps
  save_steps:
    value: 1000
  freeze_base_model:
    values:
      - false
      - true
  push_to_hub:
    value: false
  filter_threshold:
    value: 1
  feat_proj_dropout:
    values:
      - 0.0
      - 0.1
      - 0.2
  attention_dropout:
    values:
      - 0.0
      - 0.1
      - 0.2
  activation_dropout:
    values:
      - 0.0
      - 0.1
      - 0.2
  hidden_dropout:
    values:
      - 0.0
      - 0.1
      - 0.2
  final_dropout:
    values:
      - 0.0
      - 0.1
      - 0.2
  mask_time_prob:
    values:
      - 0.0
      - 0.1
      - 0.2
  mask_time_length:
    values:
      - 10
      - 15
      - 20
  mask_feature_prob:
    values:
      - 0.0
      - 0.1
      - 0.2
  mask_feature_length:
    values:
      - 10
      - 15
      - 20
program: run_audio_classification.py
project: mms-lid-accent-classification
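A sweep config like the one above is typically registered and run with wandb; a minimal sketch using the wandb Python API (the YAML file name is an assumption, not part of this commit):

import yaml
import wandb

# Register the sweep from the YAML config, then start an agent that repeatedly
# samples hyperparameters and launches `program` with the sampled ${args}.
with open("accent_sweep.yaml") as f:
    sweep_config = yaml.safe_load(f)
sweep_id = wandb.sweep(sweep=sweep_config, project="mms-lid-accent-classification")
wandb.agent(sweep_id)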
@@ -2,36 +2,37 @@
python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc" \
--train_dataset_config_name "default+en_accented+default" \
--train_split_name "train+test+validation" \
--train_label_column_name "accent+accent+accent" \
--eval_dataset_name "sanchit-gandhi/edacc" \
--train_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
--train_dataset_config_name "default" \
--train_split_name "train" \
--train_label_column_name "labels" \
--eval_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
--eval_dataset_config_name "default" \
--eval_split_name "test" \
--eval_label_column_name "accent" \
--eval_label_column_name "labels" \
--output_dir "./" \
--do_train \
--do_eval \
--overwrite_output_dir \
--remove_unused_columns False \
--fp16 \
--fp16_full_eval \
--learning_rate 1e-4 \
--max_length_seconds 20 \
--attention_mask False \
--warmup_ratio 0.1 \
--num_train_epochs 5 \
--min_length_seconds 5 \
--attention_mask \
--warmup_steps 100 \
--max_steps 2000 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--preprocessing_num_workers 16 \
--preprocessing_num_workers 4 \
--dataloader_num_workers 4 \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--load_best_model_at_end True \
--metric_for_best_model "accuracy" \
--save_total_limit 3 \
--freeze_base_model \
--push_to_hub \
--evaluation_strategy "steps" \
--eval_steps 500 \
--save_strategy "no" \
--save_steps 2000 \
--freeze_base_model True \
--push_to_hub False \
--trust_remote_code
#!/usr/bin/env bash
python run_audio_classification.py \
--model_name_or_path "facebook/mms-lid-126" \
--train_dataset_name "stable-speech/concatenated-normalized-accent-dataset+stable-speech/concatenated-common-voice-15-accented" \
--train_dataset_config_name "default+default" \
--train_split_name "train+train" \
--train_label_column_name "labels+labels" \
--eval_dataset_name "stable-speech/concatenated-normalized-accent-dataset" \
--eval_dataset_config_name "default" \
--eval_split_name "test" \
--eval_label_column_name "labels" \
--output_dir "./" \
--do_train \
--do_eval \
--overwrite_output_dir \
--remove_unused_columns False \
--fp16 \
--fp16_full_eval \
--learning_rate 1e-4 \
--lr_scheduler_type "constant_with_warmup" \
--max_length_seconds 20 \
--min_length_seconds 5 \
--attention_mask \
--warmup_steps 100 \
--max_steps 5000 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--preprocessing_num_workers 4 \
--dataloader_num_workers 4 \
--logging_strategy "steps" \
--logging_steps 10 \
--evaluation_strategy "steps" \
--eval_steps 1000 \
--save_strategy "no" \
--save_steps 5000 \
--filter_threshold 0.01 \
--freeze_base_model False \
--gradient_checkpointing \
--push_to_hub False \
--trust_remote_code
@@ -17,25 +17,23 @@ metric:
  name: eval/accuracy
parameters:
  model_name_or_path:
    values:
      - facebook/mms-lid-126
      - openai/whisper-large-v3
    value: facebook/mms-lid-126
  train_dataset_name:
    value: sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc
    value: stable-speech/concatenated-accent-dataset
  train_dataset_config_name:
    value: default+en_accented+default
    value: default
  train_split_name:
    value: train+test+validation
    value: train
  train_label_column_name:
    value: accent+accent+accent
    value: labels
  eval_dataset_name:
    value: sanchit-gandhi/edacc
    value: stable-speech/concatenated-accent-dataset
  eval_dataset_config_name:
    value: default
  eval_split_name:
    value: test
  eval_label_column_name:
    value: accent
    value: labels
  output_dir:
    value: ./
  remove_unused_columns:
@@ -45,13 +43,13 @@ parameters:
  lr_scheduler_type:
    value: constant_with_warmup
  max_length_seconds:
    value: 10 # give some data diversity for longer audio samples
    value: 20 # give some data diversity for longer audio samples
  min_length_seconds:
    value: 5
    value: 7
  attention_mask:
    value: false
    value: true
  warmup_steps:
    value: 50
    value: 100
  max_steps:
    value: 2000
  per_device_train_batch_size:
@@ -59,7 +57,7 @@ parameters:
  per_device_eval_batch_size:
    value: 16
  preprocessing_num_workers:
    value: 16
    value: 4
  dataloader_num_workers:
    value: 4
  logging_strategy:
@@ -69,7 +67,7 @@ parameters:
  evaluation_strategy:
    value: steps
  eval_steps:
    value: 2000
    value: 1000
  save_strategy:
    value: steps
  save_steps:
@@ -77,7 +75,11 @@ parameters:
  metric_for_best_model:
    value: accuracy
  freeze_base_model:
    value: false
    values:
      - false
      - true
  group_by_length:
    value: false # TODO(SG): batch by length
  push_to_hub:
    value: false
program: run_audio_classification.py
#!/usr/bin/env bash
python run_dataset_concatenation.py \
--dataset_name "sanchit-gandhi/vctk+facebook/voxpopuli+sanchit-gandhi/edacc-normalized" \
--dataset_config_name "default+en_accented+default" \
--dataset_split_name "train+test+validation" \
--label_column_name "accent+accent+accent" \
--text_column_name "text+normalized_text+text" \
--speaker_column_name "speaker_id+speaker_id+speaker" \
--batch_size 500 \
--output_dir "./concatenated-dataset"
python run_dataset_concatenation.py \
--dataset_name "sanchit-gandhi/edacc-normalized" \
--dataset_config_name "default" \
--dataset_split_name "test" \
--label_column_name "accent" \
--text_column_name "text" \
--speaker_column_name "speaker" \
--batch_size 500 \
--output_dir "./concatenated-dataset-test"
#!/usr/bin/env bash
python run_dataset_concatenation.py \
--dataset_name "stable-speech/common_voice_15_0_accented" \
--dataset_config_name "en" \
--dataset_split_name "train" \
--label_column_name "accent" \
--text_column_name "sentence" \
--speaker_column_name "client_id" \
--batch_size 250 \
--preprocessing_num_workers 4 \
--output_dir "./concatenated-dataset-cv"
python run_dataset_concatenation.py \
--dataset_name "stable-speech/common_voice_15_0_accented" \
--dataset_config_name "en" \
--dataset_split_name "test" \
--label_column_name "accent" \
--text_column_name "sentence" \
--speaker_column_name "client_id" \
--batch_size 250 \
--preprocessing_num_workers 4 \
--output_dir "./concatenated-dataset-cv-test"
@@ -6,7 +6,7 @@ import sys
from dataclasses import dataclass, field

import soundfile as sf
from datasets import Audio, Dataset, DatasetDict
from datasets import Audio, Dataset, DatasetDict, load_dataset
from tqdm import tqdm
from transformers import HfArgumentParser
@@ -53,88 +53,6 @@ class DataTrainingArguments:
    )


ACCENT_MAPPING = {
    "Italian": "Italian",
    "International": "Unknown",
    "American": "American",
    "English": "English",
    "Latin American": "Latin American",
    "British": "English",
    "Romanian": "Romanian",
    "Standard Indian English": "Indian",
    "Trans-Atlantic": "Irish",
    "Slightly American": "American",
    "European": "Unknown",
    "Scottish (Fife)": "Scottish",
    "English with Scottish inflections": "Scottish",
    "Indian": "Indian",
    "Asian": "Asian",
    "NA": "Unknown",
    "German": "German",
    "South London": "English",
    "Dutch": "Dutch",
    "Mostly West Coast American with some Australian Intonation": "American",
    "Japanese": "Japanese",
    "Chinese": "Chinese",
    "Generic middle class white person": "English",
    "French": "French",
    "Chinese accent or mixed accent(US, UK, China..) perhaps": "Chinese",
    "American accent": "American",
    "Catalan": "Catalan",
    "American, I guess.": "American",
    "Spanish American": "Latin American",
    "Spanish": "Spanish",
    "Standard American,Scottish": "American",
    "Bulgarian": "Bulgarian",
    "Latin": "Latin American",
    "Latín American": "Latin American",
    "Mexican": "Latin American", # TODO: un-generalise latin american accents?
    "North American": "American",
    "Afrian": "African",
    "Nigerian": "African", # TODO: un-generalise african accents?
    "East-European": "Eastern European",
    "Eastern European": "Eastern European",
    "Southern London": "English",
    "American with a slight accent": "American",
    "American-ish": "American",
    "Indian / Pakistani accent": "Indian",
    "Pakistani/American": "Pakistani",
    "African accent": "African",
    "Kenyan": "African", # TODO: un-generalise african accents?
    "Ghanaian": "African", # TODO: un-generalise african accents?
    "Spanish accent": "Spanish",
    "Lithuanian": "Lithuanian",
    "Lithuanian (eastern European)": "Lithuanian",
    "Indonesian": "Indonesian",
    "Egyptian": "Egyptian",
    "South African English": "South African",
    "Neutral": "English",
    "Neutral accent": "English",
    "Neutral English, Italian": "English",
    "Fluent": "Unknown",
    "Glaswegian": "Scottish",
    "Glaswegian (not slang)": "Scottish",
    "Irish": "Irish",
    "Jamaican": "Jamaican",
    "Jamaican accent": "Jamaican",
    "Irish/ Dublin": "Irish",
    "South Dublin Irish": "Irish",
    "italian": "Italian",
    "italian mixed with American and British English": "Italian",
    "Italian mixed with American accent": "Italian",
    "South American": "Latin American",
    "Brazilian accent": "Latin American", # TODO: un-generalise latin american accents?
    "Israeli": "Israeli",
    "Vietnamese accent": "Vietnamese",
    "Southern Irish": "Irish",
    "Slight Vietnamese accent": "Vietnamese",
    "Midwestern United States": "American",
    "Vietnamese English": "Vietnamese",
    "Vietnamese": "Vietnamese",
    "": "Unknown",
}


def main():
    # 1. Parse input arguments
    parser = HfArgumentParser(DataTrainingArguments)
@@ -155,9 +73,23 @@ def main():
        "How would you describe your accent in English? (e.g. Italian, Glaswegian)"
    ]

    accent_dataset = load_dataset("sanchit-gandhi/edacc_accents", split="train")

    def format_dataset(batch):
        batch["speaker_id"] = (
            batch["Final-Participant_ID"].replace("EAEC", "EDACC").replace("P1", "-A").replace("P2", "-B")
        )
        return batch

    accent_dataset = accent_dataset.map(format_dataset, remove_columns=["Final-Participant_ID"])

    # 2. Clean accents for each speaker
    linguistic_background_clean = {
        participant: ACCENT_MAPPING[accent.strip()] for participant, accent in linguistic_background.items()
        participant: accent.strip()
        for participant, accent in zip(accent_dataset["speaker_id"], accent_dataset["English_Variety"])
    }
    linguistic_variety = {
        participant: l1.strip() for participant, l1 in zip(accent_dataset["speaker_id"], accent_dataset["L1_Variety"])
    }

    # 3. Initialize dataset dict
@@ -207,7 +139,7 @@ def main():
        # add gender/l1 information
        all_genders.append(re.search(gender_pat, gender_l1).group(1))
        all_l1s.append(re.search(l1_pat, gender_l1).group(1))
        all_l1s.append(linguistic_variety[speaker])

        # read audio file if different from previous
        if file != current_audio:
@@ -238,7 +170,7 @@ def main():
            "accent": all_normalized_accents,
            "raw_accent": all_raw_accents,
            "gender": all_genders,
            "language": all_l1s,
            "l1": all_l1s,
            "audio": all_audio_paths,
        }
    ).cast_column("audio", Audio())
@@ -3,5 +3,5 @@
python prepare_edacc.py \
--dataset_dir "/fsx/sanchit/edacc/edacc_v1.0" \
--output_dir "/fsx/sanchit/edacc_processed" \
--hub_dataset_id "sanchit-gandhi/edacc" \
--push_to_hub True
--hub_dataset_id "sanchit-gandhi/edacc-normalized" \
--push_to_hub
#!/usr/bin/env bash
python run_prompt_creation.py \
--dataset_name "ylacombe/libritts_r_tags_and_text" \
accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \
--dataset_name "stable-speech/libritts-r-tags-and-text" \
--dataset_config_name "clean" \
--model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
--per_device_eval_batch_size 256 \
--attn_implementation "flash_attention_2" \
--per_device_eval_batch_size 64 \
--attn_implementation "sdpa" \
--dataloader_num_workers 4 \
--output_dir "./" \
--load_in_4bit \
--push_to_hub \
--hub_dataset_id "sanchit-gandhi/libritts_r_tags_and_text_generated"
--hub_dataset_id "stable-speech/libritts-r-tags-and-text-generated"
accelerate launch --multi_gpu --mixed_precision=fp16 --num_processes=8 run_prompt_creation.py \
--dataset_name "stable-speech/libritts-r-tags-and-text" \
--dataset_config_name "other" \
--model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
--per_device_eval_batch_size 64 \
--attn_implementation "sdpa" \
--dataloader_num_workers 4 \
--output_dir "./" \
--load_in_4bit \
--push_to_hub \
--hub_dataset_id "stable-speech/libritts-r-tags-and-text-generated"
#!/usr/bin/env bash
python run_prompt_creation.py \
--dataset_name "ylacombe/libritts_r_test_tag" \
--dataset_config_name "default" \
--dataset_name "ylacombe/libritts_r_tags_and_text" \
--dataset_config_name "clean" \
--dataset_split_name "dev.clean" \
--max_eval_samples 32 \
--model_name_or_path "TinyLlama/TinyLlama-1.1B-Chat-v1.0" \
--model_name_or_path "mistralai/Mistral-7B-Instruct-v0.2" \
--per_device_eval_batch_size 2 \
--attn_implementation "sdpa" \
--dataloader_num_workers 0 \
--output_dir "./" \
--load_in_4bit \
--push_to_hub \
--hub_dataset_id "sanchit-gandhi/libritts_r_test_tag_generated"
--load_in_4bit
@@ -58,12 +58,180 @@ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000
    return wav[random_offset : random_offset + sample_length]


def deterministic_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000) -> np.ndarray:
    """Take the first `max_length` seconds from the input audio."""
    sample_length = int(round(sample_rate * max_length))
    if len(wav) <= sample_length:
        return wav
    return wav[0:sample_length]
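
# A minimal usage sketch of the two subsampling helpers above (dummy 30 s
# waveform at 16 kHz; values are illustrative, not part of the script):
# >>> import numpy as np
# >>> wav = np.zeros(16000 * 30)
# >>> len(deterministic_subsample(wav, max_length=20.0)) / 16000  # always the first 20 s
# 20.0
# >>> len(random_subsample(wav, max_length=20.0)) / 16000         # a random 20 s crop
# 20.0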
# This list first defines the accent prefixes, which we use to strip the accent from CV
# e.g. England, southern accent, slight west-country expression -> England
# TODO(YL): update this with any CV test prefixes not present in the train set
STARTS_WITH = [
    "Afrikaans",
    "American",
    "Australian",
    "Bangladeshi",
    "Canadian",
    "Chinese",
    "Dutch",
    "Eastern European",
    "European",
    "England",
    "English",
    "German",
    "Filipino",
    "India",
"Irish" "Israeli",
"Italian",
"Japanese",
"Kenyan",
"Northern Irish",
"New Zealand",
"Nigerian",
"Malaysian",
"Russian",
"Scottish",
"Singaporean",
"Slavic",
"South African",
"Southern African",
"Swedish",
"Swiss",
"United States English",
"West Indies",
"french",
"polish",
"serbian",
]
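
# An illustrative walk-through of the prefix stripping (hypothetical raw label,
# not from the dataset): "England, southern accent, slight west-country
# expression" starts with the prefix "England", so preprocess_labels() below
# truncates it to just "England" before the mapping is applied.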
# This dictionary is used to map the un-normalised accent names to normalised ones
# TODO(YL): update this with any CV test mappings not present in the train set
ACCENT_MAPPING = {
    "British": "English",
    "Canadian": "American",
    # "Canadian": "American", TODO(SG): decide whether to normalize these to closely related accents
    # "New zealand": "Australian",
    "Northern irish": "Irish",
    "New zealand": "Australian",
    "Pakistani": "Indian",
    "Mainstream u s english": "American",
    "Southern british english": "English",
    "Indian english": "Indian",
    "Scottish english": "Scottish",
    "Don't know": "Unknown",
    "Nigerian english": "Nigerian",
    "Kenyan english": "Kenyan",
    "Ghanain english": "Ghanain",
    "Jamaican english": "Jamaican",
    "Indonesian english": "Indonesian",
    "South african english": "South african",
    "Irish english": "Irish",
    "Latin": "Latin american",
    "European": "Unknown", # Too general
    "Eastern european": "Eastern european", # TODO(SG): keep for now, but maybe remove later as too general
    "Bangladeshi": "Indian",
    "England": "English",
    "India": "Indian",
    "Afrikaans": "South african",
    "California": "American",
    "Nepali": "Indian",
    "New york city": "American",
    "New jerseyan": "American",
    "Northumbrian british english": "English",
    "Nottinghamshire,east midlands": "English",
    "Southern african": "South african",
    "United states english": "American",
    "West indies": "Jamaican",
    "2nd language": "Unknown", # Too vague
    "A savage texas gentleman": "American",
    "A variety of texan english with some german influence that has undergone the cot-caught merger": "American",
    "A'lo": "Unknown", # Unclear
    "Academic southern english,england english": "English",
    "Argentinian english": "Latin american",
    "Austrian": "German",
    "Bangladesh,india and south asia (india, pakistan, sri lanka)": "Indian",
    "Brazillian accent": "Brazilian",
    "British accent": "English",
    "Caribbean canadian": "Unknown", # Specific combination not listed
    "Colombian accent": "Latin american",
    "Czech accent": "Czech",
    "East african khoja": "Unknown", # Specific community
    "East indian": "Indian",
    "East london": "English",
    "England,london,academic": "English",
    "Filipino": "Unknown", # Unique blend
    "Fluent,e sl,european": "Unknown", # Too vague
    "Generic european": "Unknown", # Too vague
    "Georgian english": "Unknown", # No direct match
    "Ghanaian english accent,african regular reader": "Unknown", # Specific category not listed
    "Haitian creole": "Unknown", # Unique blend
    "Hispanic": "Latin american",
    "Hispanic/latino": "Latin american",
    "Hong kong english": "Chinese",
    "Hong kong english,scottish english": "Chinese",
    "Hunglish": "Hungarian",
    "I think mine accent is influenced by indian accent ,yes please. ,india and south asia (india, pakistan, sri lanka)": "Indian",
    "I was born in england and have lived in australia, canada and france.": "English",
    "International english,united states english,australian english": "American",
    "Israeli": "Unknown", # No direct match
    "Israeli english": "Unknown", # No direct match
    "Javanese,indonesian english,malaysian english": "Indonesian",
    "Kazakhstan english": "Unknown", # No direct match
    "Kiwi": "New zealand", # Could be generalised to Australian
    "Latin america,united states english": "Latin american",
    "Latin american accent": "Latin american",
    "Latin english": "Unknown", # Too vague
    "Latino": "Latin american",
    "Latvian": "Latvian", # Note: added new
    "Little latino,united states english,second language": "Latin american",
    "Liverpool english,lancashire english,england english": "English",
    "Liverpudlian english": "English",
    "Malaysian english": "Malaysian", # Note: added new
    "Mexican accent": "Latin american",
    "Mid-atlantic united states english,philadelphia, pennsylvania, united states english,united states english,philadelphia style united states english": "American",
    "Mid-atlantic,england english,united states english": "American",
    "Midatlantic,england english": "American",
    "Midwestern states (michigan),united states english": "American",
    "Mild northern england english": "English",
    "Minor french accent": "French",
    "Mix of american and british ,native polish": "Polish",
    "Mix of american and british accent": "Unknown", # Combination not clearly mapped
    "Mostly american with some british and australian inflections": "Unknown", # Combination not clearly mapped
    "My accent is influenced by the phones of all letters within a sentence.,southern african (south africa, zimbabwe, namibia)": "South african",
"New zealand english": "New Zealand English",
"Nigeria english": "Nigerian", # Note: added new
"Non native speaker from france": "French",
"Non-native": "Unknown", # Too vague
"Non-native,german accent": "German",
"North european english": "Unknown", # Too broad
"Norwegian": "Norwegian", # Note: added new
"Ontario,canadian english": "Canadian", # Note: added new
"Polish english": "Polish",
"Rhode island new england accent": "American",
"Singaporean english": "Singaporean", # Note: added new
"Slavic": "Eastern european",
"Slighty southern affected by decades in the midwest, 4 years in spain and germany, speak some german, spanish, polish. have lived in nine states.": "Unknown", # Complex blend
"South african": "South african",
"South atlantic (falkland islands, saint helena)": "Unknown", # Specific regions not listed
"South australia": "Australian",
"South indian": "Indian",
"Southern drawl": "American",
"Southern texas accent,united states english": "American",
"Southern united states,united states english": "American",
"Spanish bilingual": "Spanish",
"Spanish,foreign,non-native": "Spanish",
"Strong latvian accent": "Latvian",
"Swedish accent": "Swedish", # Note: added new
"Transnational englishes blend": "Unknown", # Too vague
"U.k. english": "English",
"Very slight russian accent,standard american english,boston influence": "American",
"Welsh english": "Welsh",
"West african": "Unknown", # No specific West African category
"West indian": "Unknown", # Caribbean, but no specific match
"Western europe": "Unknown", # Too broad
"With heavy cantonese accent": "Chinese",
}
@@ -74,7 +242,10 @@ def preprocess_labels(label: str) -> str:
        language_code = label.split("_")[-1]
        label = LANGUAGES[language_code]
    # VCTK labels for two words are concatenated into one (NewZealand -> New Zealand)
    label = re.sub(r"(\w)([A-Z])", r"\1 \2", label)
    label = re.sub(r"(\w)([A-Z])", r"\1 \2", label).strip()
    for prefix in STARTS_WITH:
        if label.startswith(prefix):
            label = prefix
    # convert Whisper language code (polish) to capitalised (Polish)
    label = label.capitalize()
    if label in ACCENT_MAPPING:
@@ -248,6 +419,52 @@ class ModelArguments:
        default=True,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )
    attention_dropout: float = field(
        default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
    )
    activation_dropout: float = field(
        default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
    )
    feat_proj_dropout: float = field(default=0.0, metadata={"help": "The dropout ratio for the projected features."})
    hidden_dropout: float = field(
        default=0.0,
        metadata={
            "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
        },
    )
    final_dropout: float = field(
        default=0.0,
        metadata={"help": "The dropout probability for the final projection layer."},
    )
    mask_time_prob: float = field(
        default=0.05,
        metadata={
            "help": (
                "Probability of each feature vector along the time axis to be chosen as the start of the vector "
                "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature "
                "vectors will be masked along the time axis."
            )
        },
    )
    mask_time_length: int = field(
        default=10,
        metadata={"help": "Length of vector span to mask along the time axis."},
    )
    mask_feature_prob: float = field(
        default=0.0,
        metadata={
            "help": (
                "Probability of each feature vector along the feature axis to be chosen as the start of the vector "
                "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` "
                "feature bins will be masked along the feature axis."
            )
        },
    )
    mask_feature_length: int = field(
        default=10,
        metadata={"help": "Length of vector span to mask along the feature axis."},
    )
    layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
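
    # An illustrative reading of the masking arithmetic described above (not
    # part of the script): with mask_time_prob=0.05, a 1000-frame sequence and
    # mask_time_length=10, roughly 0.05 * 1000 // 10 = 5 span starts are drawn,
    # each masking a 10-frame window, so on the order of 50 frames are masked.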
def convert_dataset_str_to_list(
@@ -467,9 +684,11 @@ def main():
    if training_args.do_eval:
        dataset_names_dict = convert_dataset_str_to_list(
            data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
            data_args.eval_dataset_config_name
            if data_args.eval_dataset_config_name
            else data_args.train_dataset_config_name,
            (
                data_args.eval_dataset_config_name
                if data_args.eval_dataset_config_name
                else data_args.train_dataset_config_name
            ),
            splits=data_args.eval_split_name,
            label_column_names=data_args.eval_label_column_name,
        )
@@ -544,37 +763,39 @@ def main():
    sampling_rate = feature_extractor.sampling_rate
    model_input_name = feature_extractor.model_input_names[0]

    # filter training data with non-valid labels
    def is_label_valid(label):
        return label != "Unknown"

    def prepare_dataset(batch):
        batch["length"] = len(batch["audio"]["array"])
        batch["labels"] = preprocess_labels(batch["labels"])
        return batch

    raw_datasets = raw_datasets.filter(
        is_label_valid,
        input_columns=["labels"],
    raw_datasets = raw_datasets.map(
        prepare_dataset,
        num_proc=data_args.preprocessing_num_workers,
        desc="Filtering by labels",
        desc="Computing audio length",
    )

    # filter training data with inputs < min_input_length
    max_input_length = data_args.max_length_seconds * sampling_rate
    min_input_length = data_args.min_length_seconds * sampling_rate

    def is_audio_valid(audio):
        return max_input_length > len(audio["array"]) > min_input_length

    def is_audio_valid(input_length):
        return input_length > min_input_length

    raw_datasets = raw_datasets.filter(
        is_audio_valid,
        input_columns=["audio"],
        input_columns=["length"],
        num_proc=data_args.preprocessing_num_workers,
        desc="Filtering by audio length",
    )

    # Prepare label mappings
    raw_datasets = raw_datasets.map(
        lambda label: {"labels": preprocess_labels(label)},
    # filter training data with non-valid labels
    def is_label_valid(label):
        return label != "Unknown"

    raw_datasets = raw_datasets.filter(
        is_label_valid,
        input_columns=["labels"],
        num_proc=data_args.preprocessing_num_workers,
        desc="Pre-processing labels",
        desc="Filtering by labels",
    )

    # Print a summary of the labels to stdout (helps identify low-label classes that could be filtered)
@@ -593,11 +814,11 @@ def main():
        if freq < data_args.filter_threshold:
            labels_to_remove.append(lab)

    # filter training data with label freq below threshold
    def is_label_valid(label):
        return label not in labels_to_remove

    if len(labels_to_remove):
        # filter training data with label freq below threshold
        def is_label_valid(label):
            return label not in labels_to_remove

        raw_datasets = raw_datasets.filter(
            is_label_valid,
            input_columns=["labels"],
@@ -606,7 +827,9 @@ def main():
        )

    # We'll include these in the model's config to get human-readable labels in the Inference API.
    set_labels = set(raw_datasets["train"]["labels"]).union(set(raw_datasets["eval"]["labels"]))
    set_labels = set(raw_datasets["train"]["labels"])
    if training_args.do_eval:
        set_labels = set_labels.union(set(raw_datasets["eval"]["labels"]))
    label2id, id2label = {}, {}
    for i, label in enumerate(set(set_labels)):
        label2id[label] = str(i)
@@ -614,9 +837,14 @@ def main():
    def train_transforms(batch):
        """Apply train_transforms across a batch."""
        audios = [audio["array"] for audio in batch["audio"]]
        subsampled_wavs = []
        for audio in batch["audio"]:
            wav = deterministic_subsample(
                audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
            )
            subsampled_wavs.append(wav)
        inputs = feature_extractor(
            audios, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
            subsampled_wavs, return_attention_mask=model_args.attention_mask, sampling_rate=sampling_rate
        )
        output_batch = {
            model_input_name: inputs.get(model_input_name),
@@ -654,6 +882,22 @@ def main():
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    # adapt config with regularization
    config.update(
        {
            "feat_proj_dropout": model_args.feat_proj_dropout,
            "attention_dropout": model_args.attention_dropout,
            "hidden_dropout": model_args.hidden_dropout,
            "final_dropout": model_args.final_dropout,
            "mask_time_prob": model_args.mask_time_prob,
            "mask_time_length": model_args.mask_time_length,
            "mask_feature_prob": model_args.mask_feature_prob,
            "mask_feature_length": model_args.mask_feature_length,
            "layerdrop": model_args.layerdrop,
            "activation_dropout": model_args.activation_dropout,
        }
    )
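    # The config is updated *before* the model is instantiated below, so the
    # checkpoint is loaded with the swept dropout/masking values already in
    # place (assuming the updated config is what `from_pretrained` receives).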
    model = AutoModelForAudioClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
from datasets import Audio, concatenate_datasets, load_dataset
from huggingface_hub import get_full_repo_name
from transformers import HfArgumentParser, WhisperTokenizerFast


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: str = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    dataset_split_name: str = field(
        default=None,
        metadata={
            "help": "The name of the training dataset split to use (via the datasets library). Defaults to 'train'"
        },
    )
    label_column_name: str = field(
        default="labels",
        metadata={"help": "The name of the dataset column containing the labels in the dataset. Defaults to 'labels'"},
    )
    text_column_name: str = field(
        default="text",
        metadata={
            "help": "The name of the dataset column containing the text transcriptions in the dataset. Defaults to 'text'"
        },
    )
    speaker_column_name: str = field(
        default="speaker_id",
        metadata={
            "help": "The name of the dataset column containing the speaker ids in the dataset. Defaults to 'speaker_id'"
        },
    )
    dataset_cache_dir: str = field(
        default=None,
        metadata={"help": "Path to cache directory for saving and loading datasets"},
    )
    preprocessing_num_workers: int = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    batch_size: int = field(
        default=500,
        metadata={"help": "Number of examples per batch provided to the preprocessing function."},
    )
    download_only: bool = field(
        default=False,
        metadata={"help": "Whether to only do data download and skip pre-processing."},
    )
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
    )
    sampling_rate: int = field(
        default=16_000,
        metadata={
            "help": "Sampling rate at which to resample the audio data. Should be set to the same sampling rate as the target model."
        },
    )
    max_samples: int = field(
        default=None,
        metadata={
            "help": "For debugging purposes, truncate the number of examples in the dataset to this value if set."
        },
    )
    output_dir: str = field(
        default=None,
        metadata={
            "help": "Where to save the processed dataset to disk. If unspecified, uses a 'pretty' version of the "
            "original dataset name. E.g. 'facebook/voxpopuli' will be saved under 'voxpopuli'."
        },
    )
    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether or not to push the processed dataset to the Hub."},
    )
    seed: int = field(
        default=0,
        metadata={"help": "RNG seed for reproducibility. Used during the final shuffling of the combined dataset."},
    )


def convert_dataset_str_to_list(
    dataset_names,
    dataset_config_names,
    splits=None,
    label_column_names=None,
    text_column_names=None,
    speaker_column_names=None,
    dataset_samples=None,
    default_split="train",
):
    if isinstance(dataset_names, str):
        dataset_names = dataset_names.split("+")
        dataset_config_names = dataset_config_names.split("+")
        splits = splits.split("+") if splits is not None else None
        label_column_names = label_column_names.split("+") if label_column_names is not None else None
        text_column_names = text_column_names.split("+") if text_column_names is not None else None
        speaker_column_names = speaker_column_names.split("+") if speaker_column_names is not None else None
        dataset_samples = dataset_samples.split("+") if dataset_samples is not None else None

    # basic checks to ensure we've got the right number of datasets/configs/splits/columns/probs
    if len(dataset_names) != len(dataset_config_names):
        raise ValueError(
            f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(dataset_config_names)} configs."
        )

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if label_column_names is not None and len(label_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one label column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(label_column_names)} label column names."
        )

    if text_column_names is not None and len(text_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(text_column_names)} text column names."
        )

    if speaker_column_names is not None and len(speaker_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one speaker column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(speaker_column_names)} speaker column names."
        )

    if dataset_samples is not None:
        if len(dataset_samples) != len(dataset_names):
            raise ValueError(
                f"Ensure one sample is passed for each dataset, got {len(dataset_names)} datasets and "
                f"{len(dataset_samples)} samples."
            )
        dataset_samples = [float(ds_sample) for ds_sample in dataset_samples]
    else:
        dataset_samples = [None] * len(dataset_names)

    label_column_names = (
        label_column_names if label_column_names is not None else ["labels" for _ in range(len(dataset_names))]
    )
    text_column_names = (
        text_column_names if text_column_names is not None else ["text" for _ in range(len(dataset_names))]
    )
    speaker_column_names = (
        speaker_column_names if speaker_column_names is not None else ["speaker_id" for _ in range(len(dataset_names))]
    )
    splits = splits if splits is not None else [default_split for _ in range(len(dataset_names))]

    dataset_names_dict = []
    for i, ds_name in enumerate(dataset_names):
        dataset_names_dict.append(
            {
                "name": ds_name,
                "config": dataset_config_names[i],
                "split": splits[i],
                "label_column_name": label_column_names[i],
                "text_column_name": text_column_names[i],
                "speaker_column_name": speaker_column_names[i],
                "samples": dataset_samples[i],
            }
        )
    return dataset_names_dict
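
# A minimal usage sketch of the "+"-separated convention this helper parses
# (dataset names below are placeholders, not real Hub repos):
# >>> convert_dataset_str_to_list("org/ds_a+org/ds_b", "default+default", splits="train+train")
# [{'name': 'org/ds_a', 'config': 'default', 'split': 'train', 'label_column_name': 'labels',
#   'text_column_name': 'text', 'speaker_column_name': 'speaker_id', 'samples': None},
#  {'name': 'org/ds_b', 'config': 'default', 'split': 'train', 'label_column_name': 'labels',
#   'text_column_name': 'text', 'speaker_column_name': 'speaker_id', 'samples': None}]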
def main():
    # 1. Parse input arguments
    parser = HfArgumentParser(DataTrainingArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        data_args = parser.parse_args_into_dataclasses()[0]

    dataset_names_dict = convert_dataset_str_to_list(
        data_args.dataset_name,
        data_args.dataset_config_name,
        splits=data_args.dataset_split_name,
        label_column_names=data_args.label_column_name,
        text_column_names=data_args.text_column_name,
        speaker_column_names=data_args.speaker_column_name,
    )

    # load whisper tokenizer for normalisation
    sampling_rate = data_args.sampling_rate
    tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny.en")
    max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
    batch_size = data_args.batch_size
    preprocessing_num_workers = data_args.preprocessing_num_workers

    all_vectorized_datasets = []
    for dataset_dict in dataset_names_dict:
        print(10 * "=", dataset_dict["name"], 10 * "=")
        raw_datasets = load_dataset(
            dataset_dict["name"],
            dataset_dict["config"],
            split=dataset_dict["split"],
            cache_dir=data_args.dataset_cache_dir,
            num_proc=data_args.preprocessing_num_workers,
        )
        if data_args.download_only:
            continue

        features = raw_datasets.column_names
        if dataset_dict["label_column_name"] not in features:
            raise ValueError(
                f"--label_column_name {dataset_dict['label_column_name']} not found in dataset '{dataset_dict['name']}'. "
                "Make sure to set `--label_column_name` to the correct label column - one of "
                f"{', '.join(features)}."
            )
        elif dataset_dict["label_column_name"] != "labels":
            raw_datasets = raw_datasets.rename_column(dataset_dict["label_column_name"], "labels")

        if dataset_dict["text_column_name"] not in features:
            raise ValueError(
                f"--text_column_name {dataset_dict['text_column_name']} not found in dataset '{dataset_dict['name']}'. "
                "Make sure to set `--text_column_name` to the correct text column - one of "
                f"{', '.join(features)}."
            )
        elif dataset_dict["text_column_name"] != "text":
            raw_datasets = raw_datasets.rename_column(dataset_dict["text_column_name"], "text")

        if dataset_dict["speaker_column_name"] not in features:
            raise ValueError(
                f"--speaker_column_name {dataset_dict['speaker_column_name']} not found in dataset '{dataset_dict['name']}'. "
                "Make sure to set `--speaker_column_name` to the correct speaker id column - one of "
                f"{', '.join(features)}."
            )
        elif dataset_dict["speaker_column_name"] != "speaker_id":
            raw_datasets = raw_datasets.rename_column(dataset_dict["speaker_column_name"], "speaker_id")

        raw_datasets = raw_datasets.remove_columns(
            set(raw_datasets.features.keys()) - {"audio", "labels", "text", "speaker_id"}
        )

        if data_args.max_samples is not None:
            raw_datasets = raw_datasets.select(range(data_args.max_samples))

        raw_datasets = raw_datasets.cast_column(data_args.audio_column_name, Audio(sampling_rate=sampling_rate))
        raw_datasets = raw_datasets.sort("speaker_id")

        def filter_transcriptions(text):
            normalized_text = tokenizer.normalize(text).strip()
            return bool(normalized_text) and text.lower() != "ignore_time_segment_in_scoring"

        raw_datasets = raw_datasets.filter(
            filter_transcriptions, input_columns=["text"], desc="Filtering non-speech transcriptions"
        )

        def prepare_dataset(batch):
            audio = [sample["array"] for sample in batch["audio"]]
            input_lengths = [len(sample) for sample in audio]

            concatenated_audio = []
            concatenated_text = []
            concatenated_speaker = []
            concatenated_labels = []

            audio_sample = audio[0]
            text_sample = batch["text"][0]
            label_sample = batch["labels"][0]
            for idx in range(1, len(audio)):
                prev_speaker = batch["speaker_id"][idx - 1]
                speaker = batch["speaker_id"][idx]

                if len(audio_sample) + input_lengths[idx] < max_input_length:
                    if speaker == prev_speaker:
                        # we have no information about whether the segments follow on sequentially
                        # so we just ensure the same speaker as we concatenate across files
                        audio_sample = np.append(audio_sample, audio[idx])
                        # extra spaces in the text transcription don't matter, since we only use it for the WER computation
                        text_sample += " " + batch["text"][idx]
                    else:
                        # segments do not follow sequentially, save the audio and start looping again
                        concatenated_audio.append(audio_sample)
                        concatenated_text.append(text_sample)
                        concatenated_labels.append(label_sample)
                        # the finished sample belongs to the previous speaker
                        concatenated_speaker.append(prev_speaker)
                        audio_sample = audio[idx]
                        text_sample = batch["text"][idx]
                        label_sample = batch["labels"][idx]
                else:
                    # concatenated audio exceeds max length, save the audio and start looping again
                    concatenated_audio.append(audio_sample)
                    concatenated_text.append(text_sample)
                    concatenated_labels.append(label_sample)
                    # the finished sample belongs to the previous speaker
                    concatenated_speaker.append(prev_speaker)
                    audio_sample = audio[idx]
                    text_sample = batch["text"][idx]
                    label_sample = batch["labels"][idx]

            batch["audio"] = [{"array": array, "sampling_rate": sampling_rate} for array in concatenated_audio]
            batch["text"] = concatenated_text
            batch["labels"] = concatenated_labels
            batch["speaker_id"] = concatenated_speaker
            return batch

        raw_datasets = raw_datasets.map(
            prepare_dataset,
            batched=True,
            batch_size=batch_size,
            num_proc=preprocessing_num_workers,
            desc="Concatenating dataset...",
        )

        pretty_name = dataset_dict["name"].split("/")[-1]

        def postprocess_ids(speaker_id, idx):
            formatted_idx = f"{pretty_name}-{speaker_id}-{idx}"
            return {"id": formatted_idx}

        raw_datasets = raw_datasets.map(
            postprocess_ids,
            input_columns=["speaker_id"],
            with_indices=True,
            desc="Setting sample idxs...",
            num_proc=preprocessing_num_workers,
        )

        print(f"Final length {pretty_name}: ", len(raw_datasets))
        # Use numpy format so the concatenated audio arrays are returned as numpy arrays
        raw_datasets = raw_datasets.with_format("np")
        all_vectorized_datasets.append(raw_datasets)

    all_vectorized_datasets = concatenate_datasets(all_vectorized_datasets)

    dataset_features = all_vectorized_datasets.features.copy()
    dataset_features["audio"] = Audio(sampling_rate=sampling_rate)
    all_vectorized_datasets = all_vectorized_datasets.cast(
        dataset_features, batch_size=batch_size, writer_batch_size=batch_size, num_proc=preprocessing_num_workers
    )

    all_vectorized_datasets = all_vectorized_datasets.shuffle(seed=data_args.seed)
    all_vectorized_datasets.save_to_disk(data_args.output_dir)
    if data_args.push_to_hub:
        repo_name = get_full_repo_name(Path(data_args.output_dir).absolute().name)
        all_vectorized_datasets.push_to_hub(repo_name, config_name="train", max_shard_size="1GB")


if __name__ == "__main__":
    main()
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
from accelerate import Accelerator
from datasets import load_dataset, DatasetDict
from datasets import DatasetDict, load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
@@ -81,13 +81,10 @@ class ModelArguments:
    use_fast_tokenizer: Optional[bool] = field(
        default=True, metadata={"help": "Use fast tokenizer for encoding/decoding input ids"}
    )
    token: str = field(
        default=None,
    token: Optional[bool] = field(
        default=True,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
            "help": "Whether or not to use an authentication token when loading/uploading from the Hugging Face Hub"
        },
    )
    do_sample: Optional[bool] = field(default=True, metadata={"help": "Whether to use sampling mode for generation"})
@@ -326,19 +323,22 @@ def main():
        tokenizer.pad_token_id = tokenizer.bos_token_id
        model.generation_config.pad_token_id = model.generation_config.eos_token_id

    PROMPT = """ We have seven keywords that describe different attributes of an audio sample spoken by a given speaker: the speaker's gender, the speaker's accent, the amount of reverberation in the sample (high or low reverberation), the amount of noise in the sample (how clear or noisy), how monotone or animated the sample is, the speaker's pitch (high or low voice), the speaker's speed (how fast or slow the speaker is speaking).
Given these keywords, form a coherent sentence that summarises the seven attributes in a meaningful way. You can change the order of the keywords in the sentence and use common synonyms for these words, provided that the sentence summarises the attributes clearly. Keep the sentence simple - don't introduce additional information other than the keywords provided. Only return the generated sentence, not any other assistant remarks.
For example, given the following descriptors: 'female', 'Hungarian', 'slightly roomy sounding', 'fairly noisy', 'quite monotone', 'fairly low pitch', 'very slowly', a valid sentence would be: 'a woman with a deep voice speaking slowly and somewhat monotonously with a Hungarian accent in an echoey room with background noise'. Note how the seven attributes have been combined together in a simple sentence, with the ordering changed but no additional information added.
For the descriptors: {gender}, {accent}, {reverberation}, {noise}, {speech_monotony}, {pitch}, {speaking_rate}, the corresponding sentence is:"""

    SUBSET_PROMPT = """ We have six keywords that describe different attributes of an audio sample spoken by a given speaker: the speaker's gender, the amount of reverberation in the sample (high or low reverberation), the amount of noise in the sample (how clear or noisy), how monotone or animated the sample is, the speaker's pitch (high or low voice), the speaker's speed (how fast or slow the speaker is speaking).
Given these keywords, form a coherent sentence that summarises the six attributes in a meaningful way. You can change the order of the keywords in the sentence and use common synonyms for these words, provided that the sentence summarises the attributes clearly. Keep the sentence simple - don't introduce additional information other than the keywords provided. Only return the generated sentence, not any other assistant remarks.
For example, given the following descriptors: 'female', 'slightly roomy sounding', 'fairly noisy', 'quite monotone', 'fairly low pitch', 'very slowly', a valid sentence would be: 'a woman with a deep voice speaking slowly and somewhat monotonously in an echoey room with background noise'. Note how the six attributes have been combined together in a simple sentence, with the ordering changed but no additional information added.
For the descriptors: {gender}, {reverberation}, {noise}, {speech_monotony}, {pitch}, {speaking_rate}, the corresponding sentence is:"""

    # TODO(SG): add accent keyword
    PROMPT = (
        "You will be given six descriptive keywords related to an audio sample of a person's speech. These keywords include:\n"
        "1. The gender (e.g., male, female)\n"
        "2. The level of reverberation (e.g., very roomy sounding, quite roomy sounding, slightly roomy sounding, moderate reverberation, slightly confined sounding, quite confined sounding, very confined sounding)\n"
        "3. The amount of noise in the sample (e.g., very noisy, quite noisy, slightly noisy, moderate ambient sound, slightly clear, quite clear, very clear)\n"
        "4. The tone of the speaker's voice (e.g., very monotone, quite monotone, slightly monotone, moderate intonation, slightly expressive, quite expressive, very expressive)\n"
        "5. The pace of the speaker's delivery (e.g., very slowly, quite slowly, slightly slowly, moderate speed, slightly fast, quite fast, very fast)\n"
        "6. The pitch of the speaker's voice (e.g., very low pitch, quite low pitch, slightly low pitch, moderate pitch, slightly high pitch, quite high pitch, very high pitch)\n"
        "Your task is to create a text description using these keywords that accurately describes the speech sample while ensuring the description remains grammatically correct and easy to understand. You can rearrange the keyword order as necessary, and substitute synonymous terms where appropriate. If the amount of noise is 'very noisy' and the level of reverberation is 'very roomy sounding', include the term 'very bad recording' in the description. Likewise, if the amount of noise is 'very clear' and the level of reverberation is 'very confined sounding', include the term 'very good recording' in the description. Otherwise, do not add extra details beyond what has been provided, and only return the generated description.\n"
        "For example, given the following keywords: 'female', 'slightly roomy sounding', 'slightly noisy', 'quite monotone', 'slightly low pitch', 'very slowly', a valid description would be: 'a woman with a deep voice speaking slowly and somewhat monotonously in an echoey room with background noise'.\n"
        "For the keywords: '[gender]', '[reverberation]', '[noise]', '[speech_monotony]', '[pitch]', '[speaking_rate]', the corresponding description is:"
    )

    def prepare_dataset(sample):
        sample_prompt = SUBSET_PROMPT
        sample_prompt = PROMPT
        for key in EXPECTED_COLUMNS:
            sample_prompt = sample_prompt.replace(f"[{key}]", sample[key])
        sample_prompt = [{"role": "user", "content": sample_prompt}]
@@ -379,6 +379,7 @@ def main():
    all_generated_ids = []

    for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
        generated_ids = generate_step(batch)
        generated_ids = accelerator.gather_for_metrics(generated_ids)
        all_generated_ids.extend(generated_ids.cpu())

    def postprocess_dataset(sample, idx):
@@ -396,12 +397,16 @@ def main():
        with_indices=True,
    )

    accelerator.end_training()

    if accelerator.is_main_process:
        vectorized_datasets.save_to_disk(data_args.output_dir)
        if data_args.push_to_hub:
            vectorized_datasets.push_to_hub(data_args.hub_dataset_id)
            vectorized_datasets.push_to_hub(
                data_args.hub_dataset_id,
                config_name=data_args.dataset_config_name if data_args.dataset_config_name is not None else "default",
                token=model_args.token,
            )

    accelerator.end_training()